MLIR

Multi-Level IR Compiler Framework

'xegpu' Dialect

The XeGPU dialect models Intel GPU’s ISA. It models Intel Xe ISA semantics but works at the level of vector and TensorDesc data types. It provides 1:1 mappings to Xe instructions such as DPAS and 2D block load. The matrix size processed at this level exactly matches the hardware instructions or the intrinsics supported by the lower-level GPU compiler.

Operations 

xegpu.alloc_nbarrier (xegpu::AllocNbarrierOp) 

It allocates a set of named barriers.

Syntax:

operation ::= `xegpu.alloc_nbarrier` $nbarrier_num attr-dict

AllocNbarrier creates a set of named barriers as specified by nbarrier_num. Named barriers are workgroup-level resources shared by all threads in the workgroup. For example, there are up to 32 barriers (range 0-31) for each XeCore on PVC. A typical use case is that a workgroup is partitioned into N subgroups of threads (N <= 32), and each subgroup coordinates its work with a separate barrier, with barrier IDs ranging from 0 to N-1.

Attributes: 

Attribute | MLIR Type | Description
nbarrier_num | ::mlir::IntegerAttr | 64-bit signless integer attribute

xegpu.atomic_rmw (xegpu::AtomicRMWOp) 

Atomic read-modify-write operation on the TensorDesc.

Syntax:

operation ::= `xegpu.atomic_rmw` $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
              type($tensorDesc) `,` type($mask) `,` type($value) `->` type($result)

The xegpu.atomic_rmw operation performs a read-modify-write operation, free from data races, on the region described by the TensorDesc. The kind enumeration specifies the modification to be performed. The mask operand has the same shape as the TensorDesc and is used to enable or disable specific data points of the TensorDesc. The value operand represents the new value to be applied during the modification.
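To illustrate the masked semantics described above, the following Python sketch models the operation on a flat list. It is not the dialect's API: `masked_atomic_rmw`, the `offsets` list standing in for the TensorDesc, and the small `KINDS` table are hypothetical names chosen for this example. Sequential Python cannot demonstrate atomicity itself; the sketch only shows which elements are updated.

```python
import operator

# Hypothetical subset of the `kind` enumeration, mapped to Python callables.
KINDS = {
    "addi": operator.add,
    "muli": operator.mul,
    "maxs": max,
    "mins": min,
    "assign": lambda old, new: new,
}


def masked_atomic_rmw(kind, memory, offsets, mask, values):
    """Apply `kind` to memory[offsets[i]] for every lane i whose mask bit is set."""
    update = KINDS[kind]
    for offset, enabled, value in zip(offsets, mask, values):
        if enabled:  # lanes with a cleared mask bit leave memory untouched
            memory[offset] = update(memory[offset], value)


memory = [10, 20, 30, 40]
masked_atomic_rmw("addi", memory, offsets=[0, 1, 2, 3],
                  mask=[True, False, True, False], values=[1, 1, 1, 1])
print(memory)  # [11, 20, 31, 40]
```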

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
kind | ::mlir::arith::AtomicRMWKindAttr | allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14

Enum cases:

  • addf (addf)
  • addi (addi)
  • assign (assign)
  • maximumf (maximumf)
  • maxs (maxs)
  • maxu (maxu)
  • minimumf (minimumf)
  • mins (mins)
  • minu (minu)
  • mulf (mulf)
  • muli (muli)
  • ori (ori)
  • andi (andi)
  • maxnumf (maxnumf)
  • minnumf (minnumf)

Operands: 

Operand | Description
tensorDesc | TensorDesc describing the regions of data of interest.
mask | vector of 1-bit signless integer values of ranks 1/2 or 1-bit signless integer
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type

Results: 

Result | Description
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type

xegpu.create_nd_tdesc (xegpu::CreateNdDescOp) 

Create nd-tensor descriptor operation

Syntax:

operation ::= `xegpu.create_nd_tdesc` $source ``
              custom<DynamicIndexList>($offsets, $const_offsets)
              (`,` custom<DynamicIndexList>($shape, $const_shape)^
              `,` custom<DynamicIndexList>($strides, $const_strides))?
              attr-dict `:` type($source) `->` qualified(type($TensorDesc))

The “create_nd_tdesc” operation creates a TensorDescType which represents a sub-view of a 2D memory region (it can be extended to support n-D memory regions if needed in the future). Elements in the subview are contiguous in each dimension. It encodes the following important information for supporting Intel hardware features:

  • source: an object representing (the starting address/pointer of) a 2D memory region. It can be either a 2D memref object or simply a pointer represented by the uint64_t type. In the latter case, the shape and layout information of the 2D memory region must be passed explicitly via the shape and strides parameters.
  • offsets: two index values representing the offsets from the “source” in each dimension at which the subview of the target memory will be created. They are encoded via two variables, “offsets” and “const_offsets”, so that the operation accepts various forms, such as operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
  • shape: the shape information of the memory region pointed to by the “source”. It is typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. But if “source” is simply a pointer represented as the uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the shape information has to be passed explicitly via the “shape” and “const_shape” arguments.
  • strides: the strides of the memory region pointed to by the “source”. Similar to shape, they are typically encoded via the MemRefType of the source. But if “source” is simply a pointer represented as the uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the strides information has to be passed explicitly via the “strides” and “const_strides” arguments.

Example 1 (suppose the tensor shape inferred by the compiler is 8x16):

%0 = memref.alloc() : memref<1024x1024xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0]: memref<1024x1024xf32> -> TensorDesc<8x16xf32>

Example 2 (suppose the tensor shape inferred by the compiler is 8x16):

%0 = memref.alloc(%h, %w) : memref<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: memref<?x?xf32> -> TensorDesc<8x16xf32>

Example 3 (suppose the tensor shape inferred by the compiler is 8x16):

%0 = … : ui64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
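To make the role of offsets and strides concrete, the following Python sketch computes which linear element indices of the source a 2D block covers. It is only a model of the addressing, assuming a row-major source; `nd_tdesc_indices` and its parameters are hypothetical names and not part of the dialect.

```python
def nd_tdesc_indices(offsets, strides, block_shape):
    """Linear element indices covered by a 2D block placed at `offsets`."""
    (off_row, off_col), (stride_row, stride_col) = offsets, strides
    rows, cols = block_shape
    return [
        [(off_row + r) * stride_row + (off_col + c) * stride_col
         for c in range(cols)]
        for r in range(rows)
    ]


# Example 1: a row-major memref<1024x1024xf32> has strides [1024, 1].
block = nd_tdesc_indices(offsets=(0, 0), strides=(1024, 1), block_shape=(8, 16))
print(block[0][:4])   # [0, 1, 2, 3]
print(block[1][0])    # 1024
print(block[7][15])   # 7183
```

In Examples 2 and 3 the same computation applies, with the strides taken from the explicit [%w, %c1] arguments instead of the memref type.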

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OffsetSizeAndStrideOpInterface, ViewLikeOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
const_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
const_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands: 

Operand | Description
source | 1D/2D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer
offsets | variadic of index
shape | variadic of index
strides | variadic of index

Results: 

Result | Description
TensorDesc | TensorDesc describing the regions of data of interest.

xegpu.create_tdesc (xegpu::CreateDescOp) 

Create scattered tensor descriptors (TensorDesc).

Syntax:

operation ::= `xegpu.create_tdesc` $source
              custom<DynamicIndexList>($offsets, $const_offsets)
              attr-dict `:`  type($source) `->` qualified(type($TensorDesc))

“create_tdesc” is similar to “create_nd_tdesc” in that it creates a Tensor Descriptor (TensorDescType) for a memory region. While “create_nd_tdesc” creates contiguous subviews, “create_tdesc” creates non-contiguous (scattered) subviews, allowing each work-item in a subgroup to specify its own offset. It accepts the following parameters:

  • source: a 1D memref or pointer (uint64_t) representing the flattened memory object.
  • offsets: an array containing the offset of each access point. Its size is fixed to the hardware-supported subgroup size, e.g., 16 on PVC, implying that each element in the array corresponds to a work-item (SIMT lane) in the subgroup.
  • chunk_size: [optional attribute] indicates the number of contiguous elements accessed for each offset; the default is 1.

Example 1. It assumes the subgroup size is 4, and accesses a[0], a[16], a[32], a[64].

%a = memref.alloc() : memref<1024xf32>
%1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>

Example 2. It assumes the subgroup size is 4, and each work-item accesses 8 elements. It accesses 32 data elements in total: a[0:7], a[16:23], a[32:39], a[64:71].

%0 = memref.alloc() : memref<1024xf32>
%1 = xegpu.create_tdesc %0[0, 16, 32, 64] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>

Example 3. It is similar to Example 2, but there are some overlaps among work-items. It accesses: a[0:7], a[4:11], a[8:15], a[12:19].

%0 = memref.alloc() : memref<1024xf32>
%1 = xegpu.create_tdesc %0[0, 4, 8, 12] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
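The index ranges quoted in the three examples can be reproduced with a short Python sketch. It only models which elements each work-item touches; `scattered_indices` is a hypothetical helper, not a dialect function.

```python
def scattered_indices(offsets, chunk_size=1):
    """Element indices touched by each work-item of a scattered descriptor."""
    return [list(range(off, off + chunk_size)) for off in offsets]


# Example 1: one element per work-item.
print(scattered_indices([0, 16, 32, 64]))
# [[0], [16], [32], [64]]

# Example 3: chunk_size = 8 with overlapping ranges.
lanes = scattered_indices([0, 4, 8, 12], chunk_size=8)
print([(lane[0], lane[-1]) for lane in lanes])
# [(0, 7), (4, 11), (8, 15), (12, 19)]
```

The outer list has one entry per work-item and the inner list has chunk_size entries, which mirrors the 4x8 shape of the TensorDesc in Examples 2 and 3.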

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
chunk_size | ::mlir::IntegerAttr | 64-bit signless integer attribute

Operands: 

Operand | Description
source | 1D/2D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer
offsets | variadic of index

Results: 

Result | Description
TensorDesc | TensorDesc describing the regions of data of interest.

xegpu.dpas (xegpu::DpasOp) 

It performs an MMA (matrix multiply-accumulate) computation.

Syntax:

operation ::= `xegpu.dpas` $lhs `,` $rhs (`,` $acc^)? attr-dict `:` type($lhs)`,` type($rhs) (`,` type($acc)^)?  `->` type($result)

DPAS performs matrix multiplication on matrix A of size mxk and matrix B of size kxn, and accumulates onto matrix C of size mxn, producing a result matrix of the same size, where m=8, n=16 and k=8 * 32/bit_width_of_elem_type. So for the fp16 data type, the matrices are A: vector<8x16xf16>, B: vector<16x16xf16>, and C/D: vector<8x16xf32>. Besides the matrix size requirements, DPAS also requires A and B to be loaded with the required data layout. Specifically, the VNNI layout is required for the B operand. It is achieved by setting vnni_axis = 0 on the corresponding load_nd operator. To keep both operands as 3D vectors, operand A is loaded by setting vnni_axis = 1, which does not change its physical layout in registers. Due to the VNNI transformation, the A and B operands are represented as 3D vectors, with the last dimension representing the VNNI factor, which is computed as 32/bit_width_of_elem_type. Therefore, A: vector<8x16xf16> is represented as A: vector<8x8x2xf16>, and B: vector<16x16xf16> is represented as B: vector<8x16x2xf16>.

Note: on PVC, the hardware can perform load with VNNI transformation when data
      element type is 16-bit or lower precision, taking 2 or 4 elements from
      the first dimension and inserting them into the newly added innermost dimension.
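The shape arithmetic and the VNNI repacking described above can be checked with a small Python sketch. It models matrices as nested lists and is only an interpretation of the text: `dpas_shapes`, `vnni_pack`, and `shape3` are hypothetical helpers, and the packing rule (grouping `factor` consecutive elements along the chosen axis into a new innermost dimension) is an assumption derived from the description and the note.

```python
def dpas_shapes(elem_bits):
    """Return (A, B, C) 2D shapes plus the VNNI factor for an element width."""
    m, n = 8, 16
    vnni_factor = 32 // elem_bits
    k = 8 * vnni_factor
    return (m, k), (k, n), (m, n), vnni_factor


def vnni_pack(matrix, axis, factor):
    """Move `factor` consecutive elements along `axis` into a new innermost dim."""
    rows, cols = len(matrix), len(matrix[0])
    if axis == 0:
        return [[[matrix[i * factor + v][j] for v in range(factor)]
                 for j in range(cols)]
                for i in range(rows // factor)]
    return [[[matrix[i][j * factor + v] for v in range(factor)]
             for j in range(cols // factor)]
            for i in range(rows)]


def shape3(tensor):
    """Shape of a 3-level nested list."""
    return (len(tensor), len(tensor[0]), len(tensor[0][0]))


a_shape, b_shape, c_shape, factor = dpas_shapes(elem_bits=16)
print(a_shape, b_shape, c_shape, factor)  # (8, 16) (16, 16) (8, 16) 2

a = [[r * 100 + c for c in range(a_shape[1])] for r in range(a_shape[0])]
b = [[r * 100 + c for c in range(b_shape[1])] for r in range(b_shape[0])]
print(shape3(vnni_pack(a, axis=1, factor=factor)))  # (8, 8, 2)
print(shape3(vnni_pack(b, axis=0, factor=factor)))  # (8, 16, 2)
print(vnni_pack(b, axis=0, factor=factor)[0][0])    # [0, 100]
```

With axis=1 the packed A is just a row-major reshape of the original, which matches the statement that A's physical layout is unchanged, whereas axis=0 interleaves two rows of B.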

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
lhs | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2/3
rhs | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2/3
acc | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2

Results: 

Result | Description
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2

xegpu.fence (xegpu::FenceOp) 

It synchronizes memory accesses.

Syntax:

operation ::= `xegpu.fence` `memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` `` $fence_scope attr-dict

It synchronizes memory accesses between a write and a following read or write.

  1. memory_kind describes the memory kind. “global” means the global memory, “slm” means the shared local memory.
  2. fence_scope describes the scope of the fence. “workgroup” means that the scope is within each workgroup. “gpu” means that the scope is across workgroups within the GPU.

Attributes: 

Attribute | MLIR Type | Description
memory_kind | ::mlir::xegpu::MemoryScopeAttr | Describes the location of data described by a `TensorDesc`: global device memory (`Global`) or shared local memory (`SLM`).

Enum cases:

  • global (Global)
  • slm (SLM)
fence_scope | ::mlir::xegpu::FenceScopeAttr | Describes the scope of the fence. "workgroup" means that the scope is within each workgroup. "gpu" means that the scope is across workgroups within the GPU.

Enum cases:

  • workgroup (Workgroup)
  • gpu (GPU)

xegpu.init_nbarrier (xegpu::InitNbarrierOp) 

It assigns a named barrier to the current thread.

Syntax:

operation ::= `xegpu.init_nbarrier` $nbarrier_id `,` $participant_thread_num attr-dict `:`
              type($nbarrier_id) `,` type($participant_thread_num) `->` qualified(type($result))

InitNbarrierOp assigns the named barrier with the specified barrier ID (0-31) to the current thread. Multiple threads may bind to the same named barrier, and participant_thread_num specifies the total number of threads associated with the nbarrier. It returns an object of NbarrierType representing the barrier.

Operands: 

Operand | Description
nbarrier_id | 8-bit signless integer
participant_thread_num | 8-bit signless integer

Results: 

Result | Description
result | !xegpu.nbarrier, a custom XeGPU type representing a barrier.

xegpu.load (xegpu::LoadGatherOp) 

Load a set of scattered data points from memory.

Syntax:

operation ::= `xegpu.load` $TensorDesc `,` $mask prop-dict attr-dict
              `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)

It (aka. load) loads data for each work-item. The output describes the data being loaded at the subgroup level, so its size is consistent with the number of work-items in a subgroup. When the chunk_size_per_lane attribute in the TensorDesc is larger than 1, the output vector will be a 2D vector, with dim-1 corresponding to the chunk size.

The mask operand masks out memory access so that it is safe to pass out-of-boundary addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.

Example:

  %2 = xegpu.load %1, %0 {transpose = [1, 0],
                          l1_hint = #xegpu.cache_hint<cached>,
                          l2_hint = #xegpu.cache_hint<uncached>,
                          l3_hint = #xegpu.cache_hint<uncached>}
        : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
          -> vector<16xf32>
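The effect of the mask on out-of-bounds offsets can be modelled in a few lines of Python. This is a sketch of the semantics only: `masked_gather` is a hypothetical helper, and the `fill` value returned for masked lanes is an assumption, since the text does not state what masked lanes produce.

```python
def masked_gather(memory, offsets, mask, fill=None):
    """Read memory[offset] for enabled lanes; never touch memory for masked lanes."""
    return [memory[off] if enabled else fill
            for off, enabled in zip(offsets, mask)]


memory = [1.0, 2.0, 3.0, 4.0]
# The last offset is out of bounds, but its lane is masked off, so no access occurs.
print(masked_gather(memory, offsets=[0, 2, 3, 99], mask=[True, True, True, False]))
# [1.0, 3.0, 4.0, None]
```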

Attributes: 

Attribute | MLIR Type | Description
transpose | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)

Operands: 

Operand | Description
TensorDesc | TensorDesc describing the regions of data of interest.
mask | vector of 1-bit signless integer values of ranks 1/2 or 1-bit signless integer

Results: 

Result | Description
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type

xegpu.load_nd (xegpu::LoadNdOp) 

Loads an n-D block from memory (represented by TensorDesc) to registers (represented by vector).

Syntax:

operation ::= `xegpu.load_nd` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)

LoadNdOp essentially mimics the hardware block read instruction to read a block of data from memory to registers. It takes a set of optional cache hints for each level of cache: L1, L2 and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute will be masked. VNNI transformation is a hardware feature of Intel GPUs that packs data during the load of the B operand of a matrix operation when the bit width of the data type is less than 32 bits, e.g., fp16. Transpose is another Intel hardware feature, which transposes the data while loading it when the data type is fp32 or fp64. This implies that vnni and transpose cannot exist at the same time.

Example:

  xegpu.load_nd %1 {transpose = [1, 0],
                    l1_hint = #xegpu.cache_hint<cached>,
                    l2_hint = #xegpu.cache_hint<uncached>,
                    l3_hint = #xegpu.cache_hint<streaming>}
          : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
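The shape change in the example (an 8x16 descriptor producing a 16x8 vector) follows from the transpose attribute. The Python sketch below models a block load with an optional transpose on nested lists; `load_nd` here is a hypothetical stand-in that ignores cache hints and VNNI packing, and is not the operation's implementation.

```python
def load_nd(memory, offsets, block_shape, transpose=None):
    """Copy a 2D block out of a row-major list-of-lists, optionally transposed."""
    off_row, off_col = offsets
    rows, cols = block_shape
    block = [[memory[off_row + r][off_col + c] for c in range(cols)]
             for r in range(rows)]
    if transpose == [1, 0]:
        block = [list(column) for column in zip(*block)]
    return block


memory = [[r * 100 + c for c in range(32)] for r in range(32)]
value = load_nd(memory, offsets=(0, 0), block_shape=(8, 16), transpose=[1, 0])
print(len(value), len(value[0]))  # 16 8
print(value[3][5])                # 503
```

Element (3, 5) of the transposed result is element (5, 3) of the original block, i.e. row 5, column 3 of the source.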

Attributes: 

Attribute | MLIR Type | Description
vnni_axis | ::mlir::IntegerAttr | 64-bit signless integer attribute
transpose | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)

Operands: 

Operand | Description
TensorDesc | TensorDesc describing the regions of data of interest.

Results: 

Result | Description
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type

xegpu.nbarrier_arrive (xegpu::NbarrierArriveOp) 

It signals the arrival at the named barrier.

Syntax:

operation ::= `xegpu.nbarrier_arrive` $nbarrier attr-dict `:` qualified(type($nbarrier))

NbarrierArriveOp signals the hardware (or other threads) that the current thread has produced its data for the consumer threads. Once the hardware has been signalled by participant_thread_num threads for the named barrier, it notifies the threads waiting on the named barrier to continue their work.

Operands: 

Operand | Description
nbarrier | !xegpu.nbarrier, a custom XeGPU type representing a barrier.

xegpu.nbarrier_wait (xegpu::NbarrierWaitOp) 

It waits for a named barrier.

Syntax:

operation ::= `xegpu.nbarrier_wait` $nbarrier attr-dict `:` qualified(type($nbarrier))

NbarrierWaitOp signals the hardware which named barrier the current thread is waiting for, such that it can get notified when the named barrier is completed.
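The arrive/wait protocol can be pictured as a simple arrival counter. The Python sketch below is a single-threaded model under stated assumptions: `NamedBarrier` and its methods are hypothetical names, and `is_complete` is polled here, whereas on hardware a waiting thread is notified once the barrier completes.

```python
class NamedBarrier:
    """Counts arrivals until `participant_thread_num` threads have signalled."""

    def __init__(self, nbarrier_id, participant_thread_num):
        self.nbarrier_id = nbarrier_id
        self.participant_thread_num = participant_thread_num
        self.arrived = 0

    def arrive(self):
        """Model of nbarrier_arrive: record one more thread's arrival."""
        self.arrived += 1

    def is_complete(self):
        """Model of the wait condition: every participant has arrived."""
        return self.arrived >= self.participant_thread_num


barrier = NamedBarrier(nbarrier_id=1, participant_thread_num=3)
barrier.arrive()
barrier.arrive()
print(barrier.is_complete())  # False
barrier.arrive()
print(barrier.is_complete())  # True
```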

Operands: 

Operand | Description
nbarrier | !xegpu.nbarrier, a custom XeGPU type representing a barrier.

xegpu.prefetch (xegpu::PrefetchOp) 

Prefetches a set of scattered data points to cache

Syntax:

operation ::= `xegpu.prefetch` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))

It issues instructions to prefetch a set of scattered data points from memory to each level of the cache based on their cache policy. As compared to prefetch_nd, which works on non-scattered TensorDesc, it works on scattered TensorDesc instead.

Example:

  xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<cached>,
                         l3_hint = #xegpu.cache_hint<cached>}
    : !xegpu.tensor_desc<16xf16>

Attributes: 

Attribute | MLIR Type | Description
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)

Operands: 

Operand | Description
TensorDesc | TensorDesc describing the regions of data of interest.

xegpu.prefetch_nd (xegpu::PrefetchNdOp) 

Prefetches an n-D block to cache

Syntax:

operation ::= `xegpu.prefetch_nd` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))

It issues an instruction to prefetch a block of data from continuous memory regions to each level of the cache based on their cache policy.

Example:

  xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                            l2_hint = #xegpu.cache_hint<cached>,
                            l3_hint = #xegpu.cache_hint<cached>}
    : !xegpu.tensor_desc<8x16xf16>

Attributes: 

Attribute | MLIR Type | Description
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)

Operands: 

Operand | Description
TensorDesc | TensorDesc describing the regions of data of interest.

xegpu.store (xegpu::StoreScatterOp) 

Store data to scattered memory locations.

Syntax:

operation ::= `xegpu.store` $value `,` $TensorDesc `,` $mask prop-dict attr-dict
              `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)

It (aka. store) stores data to scattered memory locations. It has similar semantics to load_gather.

Example:

  xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                          l2_hint = #xegpu.cache_hint<write_back>,
                          l3_hint = #xegpu.cache_hint<write_through>}
        : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
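As a rough model of the scattered store, the Python sketch below writes one value per enabled lane. It assumes the mask behaves as described for xegpu.load, which the text implies but does not spell out; `masked_scatter` is a hypothetical helper and cache hints are ignored.

```python
def masked_scatter(memory, values, offsets, mask):
    """Write values[i] to memory[offsets[i]] for every lane whose mask bit is set."""
    for value, off, enabled in zip(values, offsets, mask):
        if enabled:  # masked lanes perform no memory access at all
            memory[off] = value


memory = [0.0] * 8
masked_scatter(memory, values=[1.0, 2.0, 3.0, 4.0],
               offsets=[0, 2, 4, 99], mask=[True, True, True, False])
print(memory)  # [1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 0.0, 0.0]
```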

Attributes: 

Attribute | MLIR Type | Description
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)

Operands: 

Operand | Description
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type
TensorDesc | TensorDesc describing the regions of data of interest.
mask | vector of 1-bit signless integer values of ranks 1/2 or 1-bit signless integer

xegpu.store_nd (xegpu::StoreNdOp) 

Stores an n-D block register region back to memory; currently only 2D is supported.

Syntax:

operation ::= `xegpu.store_nd` $value `,` $TensorDesc prop-dict attr-dict
              `:` type($value) `,` qualified(type($TensorDesc))

StoreNdOp essentially mimics the hardware block write instruction to write a block of data from registers into the memory region described by the TensorDesc. It takes a set of optional cache hints for each level of cache: L1, L2 and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute will be masked.

Example:

  xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                         l2_hint = #xegpu.cache_hint<write_back>,
                         l3_hint = #xegpu.cache_hint<write_through>}
                         : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>

Attributes: 

Attribute | MLIR Type | Description
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l2_hint::mlir::xegpu::CachePolicyAttr
Describe the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)
l3_hint::mlir::xegpu::CachePolicyAttr
Describe the cache settings for prefetch/load/store operators

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)

Operands: 

| Operand | Description |
|---|---|
| value | vector (ranks 1/2/3/4) or scalar of: 1-, 8-, 16-, 32-, or 64-bit signless, signed, or unsigned integer; 16-, 32-, or 64-bit float; bfloat16; or tf32 |
| TensorDesc | TensorDesc describing regions of interest in the data. |

xegpu.update_nd_offset (xegpu::UpdateNdOffsetOp) 

It updates the offsets for the TensorDesc.

Syntax:

operation ::= `xegpu.update_nd_offset` $TensorDesc `,`
              custom<DynamicIndexList>($offsets, $const_offsets)
              attr-dict `:` qualified(type($result))

The op updates the offsets of the given TensorDesc. The offsets are relative to the current position and are expressed in number of elements. The result is a TensorDesc of the same type as the input.

Example:

  %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
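
The relative-offset arithmetic can be sketched in Python as follows. This is only a model of the semantics described above; `update_nd_offset` and `walk` are hypothetical helper names, and the loop that steps a block along the second dimension is an assumed usage pattern rather than something the dialect prescribes.

```python
def update_nd_offset(offsets, deltas):
    """Return new block offsets: one relative shift (in elements) per dimension."""
    if len(offsets) != len(deltas):
        raise ValueError("expected one delta per dimension")
    return [o + d for o, d in zip(offsets, deltas)]


def walk(start, step, count):
    """Apply the same relative step repeatedly and collect each position."""
    positions = []
    current = list(start)
    for _ in range(count):
        current = update_nd_offset(current, step)
        positions.append(current)
    return positions


# Moving an 8x16 block 16 elements to the right, three times in a row.
print(walk([0, 0], [0, 16], 3))
```

Because the shifts are relative, applying `[0, 16]` repeatedly advances the block by one block width each time without recomputing an absolute position.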

Attributes: 

| Attribute | MLIR Type | Description |
|---|---|---|
| const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands: 

| Operand | Description |
|---|---|
| TensorDesc | TensorDesc describing regions of interest in the data. |
| offsets | variadic of index |

Results: 

| Result | Description |
|---|---|
| result | TensorDesc describing regions of interest in the data. |

xegpu.update_offset (xegpu::UpdateOffsetOp) 

It updates the offsets for the given tensor descriptor

Syntax:

operation ::= `xegpu.update_offset` $TensorDesc `,`
              custom<DynamicIndexList>($offsets, $const_offsets)
              attr-dict `:` qualified(type($TensorDesc))

It behaves similarly to update_nd_offset in that it updates the offsets of a TensorDesc, and the offsets are relative to the current position, expressed in number of elements. However, update_nd_offset updates the start point of a 2D block, so its offsets contain two elements representing the shift in each dimension. update_offset updates the offset per work-item, so its offsets contain one value per work-item, each representing the shift for that work-item.

Example:
```
  %2 = xegpu.update_offset %1, [32, 32, 32, 32]
        : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
```
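
To contrast with the 2D block case, here is a minimal Python sketch of the per-work-item update. The function name `update_offset` mirrors the op for readability only; the starting offsets `[0, 8, 16, 24]` are made-up values chosen for illustration.

```python
def update_offset(work_item_offsets, shifts):
    """Shift each work-item's offset by its own relative amount (in elements)."""
    if len(work_item_offsets) != len(shifts):
        raise ValueError("expected one shift per work-item")
    return [base + shift for base, shift in zip(work_item_offsets, shifts)]


# Four work-items, each advanced by 32 elements as in the example above.
current = [0, 8, 16, 24]          # hypothetical starting offsets
updated = update_offset(current, [32, 32, 32, 32])
```

The list length tracks the number of work-items, not the rank of the block, which is the key difference from update_nd_offset.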

Attributes: 

| Attribute | MLIR Type | Description |
|---|---|---|
| const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands: 

| Operand | Description |
|---|---|
| TensorDesc | TensorDesc describing regions of interest in the data. |
| offsets | variadic of index |

Results: 

| Result | Description |
|---|---|
| result | TensorDesc describing regions of interest in the data. |

Attributes 

CachePolicyAttr 

Describe the cache settings for prefetch/load/store operators

Syntax:

#xegpu.cache_hint<
  ::mlir::xegpu::CachePolicy   # value
>

Enum cases:

  • cached (CACHED)
  • uncached (UNCACHED)
  • streaming (STREAMING)
  • read_invalidate (READ_INVALIDATE)
  • write_back (WRITE_BACK)
  • write_through (WRITE_THROUGH)

Parameters: 

| Parameter | C++ type | Description |
|---|---|---|
| value | ::mlir::xegpu::CachePolicy | an enum of type CachePolicy |

FenceScopeAttr 

Describes the scope of fence. “workgroup” means that the scope is within each work group. “gpu” means the scope is across work groups within the gpu.

Syntax:

#xegpu.fence_scope<
  ::mlir::xegpu::FenceScope   # value
>

Enum cases:

  • workgroup (Workgroup)
  • gpu (GPU)

Parameters: 

| Parameter | C++ type | Description |
|---|---|---|
| value | ::mlir::xegpu::FenceScope | an enum of type FenceScope |

MemoryScopeAttr 

Describe the location of data described by a TensorDesc: Global device memory (Global) or Shared local memory (SLM).

Syntax:

#xegpu.memory_scope<
  ::mlir::xegpu::MemoryScope   # value
>

Enum cases:

  • global (Global)
  • slm (SLM)

Parameters: 

| Parameter | C++ type | Description |
|---|---|---|
| value | ::mlir::xegpu::MemoryScope | an enum of type MemoryScope |

TensorDescAttr 

a composite attribute for TensorDescType

Syntax:

#xegpu.tdesc_attr<
  MemoryScopeAttr,   # memory_scope
  IntegerAttr,   # array_length
  BoolAttr,   # boundary_check
  BoolAttr   # scattered
>

TensorDescAttr (or tdesc_attr) is a composite attribute defined for TensorDescType. It describes the following properties of a TensorDesc:

  1. memory_scope: where the data block described by the TensorDesc is located, Global device memory or Shared local memory. It defaults to Global.
  2. array_length: how many horizontally consecutive blocks will be loaded by a hardware load instruction. If the TensorDesc shape is 8x16 with array_length = 2, the loaded block shape will actually be 8x32. Its default value is 1.
  3. boundary_check: indicates to the hardware whether to perform an out-of-boundary check. The default value is true.
  4. scattered: differentiates TensorDescs created by create_nd_tdesc from those created by create_tdesc.
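
The effect of array_length on the loaded block shape can be worked through with a small Python sketch. `loaded_shape` is a hypothetical helper written for this illustration; it only encodes the 8x16 with array_length = 2 gives 8x32 relationship stated above.

```python
def loaded_shape(shape, array_length=1):
    """Shape of the data actually loaded for a 2D TensorDesc shape.

    The blocks are horizontally consecutive, so only the column count grows.
    """
    rows, cols = shape
    if array_length < 1:
        raise ValueError("array_length must be at least 1")
    return (rows, cols * array_length)


print(loaded_shape((8, 16), array_length=2))
```

With the default array_length of 1, the loaded shape is the same as the TensorDesc shape.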

Parameters: 

| Parameter | C++ type | Description |
|---|---|---|
| memory_scope | MemoryScopeAttr | |
| array_length | IntegerAttr | 1 |
| boundary_check | BoolAttr | true |
| scattered | BoolAttr | false |

Types 

NbarrierType 

!xegpu.nbarrier is a custom XeGPU type representing a barrier.

Syntax: !xegpu.nbarrier

TensorDescType 

TensorDesc describing regions of interest in the data.

TensorDesc is a type designed to describe regions of interest in the data, as well as some features that are unique to Intel hardware. Unlike the builtin tensor type in MLIR, it contains only the metadata and does not hold the data itself. It is designed mainly to support 2D block load/store and DPAS (matrix multiplication instruction) on Intel GPUs. It encodes the following information:

  • shape: the sizes/shape of the data block of interest, e.g., 8x16 means 8 rows, each containing 16 contiguous data elements. The rows may or may not be contiguous, depending on whether the encoding attribute is set.
  • element_type: the data type of the data element, e.g., f16, f32.

Similar to the builtin tensor, it also provides an optional attribute that encodes the following information via the TensorDescAttr object:

  • memory_scope (xegpu::MemoryScope): [optional] where the data is located, global memory or shared memory. It defaults to Global.
  • array_length (int): [optional] the number of contiguous blocks, each of size shape, that will be loaded by a block load at a time. It defaults to 1.
  • boundary_check (bool): [optional] indicates whether the operation detects the boundary and pads with zero for out-of-boundary accesses. Boundary checking is enabled by default.
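
The zero-padding behavior of boundary_check can be modeled with the following Python sketch. It covers only the case where the check is enabled; `load_block` and its arguments are hypothetical names, and a nested list stands in for the memory region.

```python
def load_block(matrix, row0, col0, shape):
    """Read a block of the given shape, padding out-of-boundary elements with 0.

    matrix: list of equally sized rows standing in for the 2D memory region.
    row0, col0: top-left position of the block; shape: (rows, cols).
    """
    rows, cols = shape
    height, width = len(matrix), len(matrix[0])
    block = []
    for r in range(row0, row0 + rows):
        out_row = []
        for c in range(col0, col0 + cols):
            in_bounds = 0 <= r < height and 0 <= c < width
            out_row.append(matrix[r][c] if in_bounds else 0)
        block.append(out_row)
    return block


# A 2x2 block that hangs over the bottom-right corner of a 2x3 region.
source = [[1, 2, 3],
          [4, 5, 6]]
corner = load_block(source, row0=1, col0=2, shape=(2, 2))
```

Only the element at row 1, column 2 lies inside the region here, so the other three positions of `corner` are filled with zeros.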

Syntax:

TensorDesc-type ::= `tensor_desc` `<` dim-list element-type (attr-list)? `>`
element-type ::= float-type | integer-type | index-type
dim-list ::= (static-dim-list `x`)?
static-dim-list ::= decimal-literal `x` decimal-literal
attr-list ::= (, memory_scope = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?

Examples:

// A block TensorDesc with 8x16 i32 elements
!xegpu.tensor_desc<8x16xi32>

// A block TensorDesc with 8x16 f32 elements
!xegpu.tensor_desc<8x16xf32>

// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
!xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>

Parameters: 

| Parameter | C++ type | Description |
|---|---|---|
| shape | ::llvm::ArrayRef<int64_t> | |
| elementType | mlir::Type | |
| encoding | mlir::Attribute | |