MLIR

Multi-Level IR Compiler Framework

'xegpu' Dialect

The XeGPU dialect that models Intel GPU’s ISA

The XeGPU dialect closely models a subset of the Xe GPU’s ISA, providing an abstraction to support high-performance GEMM code generation. It serves as a bridge dialect in the MLIR gradual lowering process, working with MLIR memref and vector types, and complements the Arith, Math, Vector, and Memref dialects. XeGPU operations are introduced for special Xe instructions not modeled by the LLVM/SPIR-V dialects, such as DPAS and 2D block load and store.

It supports a tile-based programming model, decomposing the GEMM kernel into large predefined tile sizes at the subgroup and workgroup levels. XeGPU allows the high-level GEMM algorithm to be easily expressed. Underneath, it uses target-specific recipes and hardware features to achieve optimal performance on specific hardware. By decomposing GEMM at submatrix granularity and mapping it to registers, it naturally supports optimizations like fusing with neighboring operations.

Operations 


xegpu.alloc_nbarrier (xegpu::AllocNbarrierOp) 

It allocates a set of named barriers.

Syntax:

operation ::= `xegpu.alloc_nbarrier` $nbarrier_num attr-dict

AllocNbarrier creates a set of named barriers, with the number of barriers given by nbarrier_num. Named barriers are workgroup-level resources shared by all threads in the workgroup. For example, each XeCore on PVC provides up to 32 barriers (IDs 0-31). In a typical use case, a workgroup is partitioned into N subgroups of threads (N <= 32), and each subgroup coordinates its work using a separate barrier, with IDs ranging from 0 to N-1.

Attributes: 

Attribute | MLIR Type | Description
nbarrier_num | ::mlir::IntegerAttr | 64-bit signless integer attribute

xegpu.atomic_rmw (xegpu::AtomicRMWOp) 

Atomic read-modify-write operation on the TensorDesc.

Syntax:

operation ::= `xegpu.atomic_rmw` $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
              qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)

The xegpu.atomic_rmw operation performs a read-modify-write on the region described by the TensorDesc, free from data races. The kind enumeration specifies the modification to be performed. The mask operand has the same shape as the TensorDesc and enables or disables individual data points of the TensorDesc. The value operand holds the new values to be applied during the modification.
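To make the masking behavior concrete, the following is a minimal Python sketch that models the operation on a flat list, treating each data point as one lane with an address, a mask bit, and a value. The helper name atomic_rmw_sim is hypothetical, lanes are applied in lane order as one possible serialization, and returning each lane's pre-update value (None for masked-off lanes) is an assumption borrowed from the usual read-modify-write convention; the description above does not specify the result contents.

```python
from typing import Callable, List, Optional


def atomic_rmw_sim(memory: List[float], addrs: List[int], mask: List[bool],
                   values: List[float],
                   op: Callable[[float, float], float]) -> List[Optional[float]]:
    """Hypothetical model of a masked, per-lane atomic read-modify-write.

    Lanes are applied one after another, which is one valid serialization
    of the atomic updates. Assumptions: each active lane returns the value
    it read before its update, and masked-off lanes return None.
    """
    assert len(addrs) == len(mask) == len(values)
    results: List[Optional[float]] = []
    for addr, active, val in zip(addrs, mask, values):
        if not active:
            results.append(None)  # masked-off lane: memory is not touched
            continue
        old = memory[addr]              # read
        memory[addr] = op(old, val)     # modify and write (sequential here)
        results.append(old)
    return results


# An "add" kind over 4 lanes; lanes 1 and 2 hit the same address, lane 3 is masked.
mem = [0.0] * 8
res = atomic_rmw_sim(mem, [0, 1, 1, 7], [True, True, True, False],
                     [1.0, 2.0, 3.0, 100.0], lambda a, b: a + b)
print(mem[:2], res)  # [1.0, 5.0] [0.0, 0.0, 2.0, None]
```

Because the two lanes that target address 1 are serialized, their combined effect (0 + 2 + 3) is applied without a lost update, which is the data-race freedom the operation guarantees.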

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, MemoryEffectOpInterface (MemoryEffectOpInterface), NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource, MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}, MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
kind | ::mlir::arith::AtomicRMWKindAttr | allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15

Operands: 

Operand | Description
tensorDesc | TensorDesc describing regions of interested data.
mask | fixed-length vector of 1-bit signless integer values
value | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values

Results: 

Result | Description
result | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values

xegpu.convert_layout (xegpu::ConvertLayoutOp) 

Convert the layout of the input operand

Syntax:

operation ::= `xegpu.convert_layout` $source prop-dict attr-dict `:` type($source)

convert_layout redistributes data across subgroups and/or work-items from input_layout to target_layout. Both input_layout and target_layout must correspond to the same programming scope, such as workgroup-level (wg) or subgroup-level (sg) code. The operation is not valid once the IR has been lowered to the work-item (WI) level, since that level is the end result of all distributions.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
input_layout | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance; common trait for all XeGPU layouts.
target_layout | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance; common trait for all XeGPU layouts.

Operands: 

Operand | Description
source | vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2/3/4/5/6

Results: 

Result | Description
result | vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2/3/4/5/6

xegpu.create_mem_desc (xegpu::CreateMemDescOp) 

Create a memory descriptor.

Syntax:

operation ::= `xegpu.create_mem_desc` $source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))

Creates a memory descriptor from a shared local memory (SLM) buffer and an XeGPU-specific memory layout. The resulting memory descriptor must have the same size as the underlying SLM buffer.

Arguments:

  • source : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.

Results:

  • mem_desc : the memory descriptor.
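The size rule can be checked with simple arithmetic: the number of bytes in the i8 SLM buffer must equal the element count of the descriptor times the element width in bytes. The Python helpers below (slm_bytes_for and mem_desc_fits) are hypothetical and assume a dense matrix with byte-sized elements; any padding that an XeGPU memory layout might introduce is not modeled.

```python
from math import prod
from typing import Sequence


def slm_bytes_for(shape: Sequence[int], elem_bits: int) -> int:
    """Bytes needed by a dense matrix of `shape` with `elem_bits`-wide elements."""
    assert elem_bits % 8 == 0, "sub-byte element types are not modeled here"
    return prod(shape) * elem_bits // 8


def mem_desc_fits(buffer_bytes: int, shape: Sequence[int], elem_bits: int) -> bool:
    """Hypothetical check: the mem_desc must cover exactly the i8 SLM buffer."""
    return buffer_bytes == slm_bytes_for(shape, elem_bits)


# A 32x32 f16 matrix needs a memref<2048xi8> SLM buffer.
print(slm_bytes_for((32, 32), 16))        # 2048
print(mem_desc_fits(4096, (32, 32), 16))  # False
```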

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
source | statically shaped memref of 8-bit signless integer values for shared memory

Results: 

Result | Description
mem_desc | MemDesc describing the data in SLM

xegpu.create_nd_tdesc (xegpu::CreateNdDescOp) 

Create nd-tensor descriptor operation

Syntax:

operation ::= `xegpu.create_nd_tdesc` $source ``
              custom<OptionalDynamicIndexList>($offsets, $const_offsets)
              (`,` `shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
              `,` `strides``:` custom<DynamicIndexList>($strides, $const_strides))?
              attr-dict `:` type($source) `->` qualified(type($TensorDesc))

The “create_nd_tdesc” operation creates a TensorDescType that represents a sub-view of a 1D/2D memory region within the one or two innermost dimensions of the source. (It can be extended to support n-D memory regions in the future if needed.) Elements in the sub-view are contiguous in each dimension. It encodes the following important information for supporting Intel hardware features:

Arguments:

  • source: an object representing (the starting address/pointer of) a memory region. It can be either a memref object or simply a pointer represented by a uint64_t type. For dynamic memrefs or pointers, the shape and layout information of the memory region must be passed explicitly via the shape and strides parameters.

  • offsets: index values representing the offsets from the “source” in each dimension at which the sub-view of the target memory is created. It is encoded via “offsets” and “const_offsets”, so that it can accept various forms, such as operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).

  • shape: the shape of the memory region pointed to by the “source”. It is typically encoded in the MemRefType of the source, e.g., memref<4096x4096xf16>. But if “source” is simply a pointer represented as a uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the shape must be passed explicitly via the “shape” and “const_shape” arguments.

  • strides: the strides of the memory region pointed to by the “source”. Like the shape, it is typically encoded in the MemRefType of the source. But if “source” is simply a pointer represented as a uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the strides must be passed explicitly via the “strides” and “const_strides” arguments.

Results:

  • res: nd tensor descriptor

Example 1 (suppose the tensor shape inferred by the compiler is 8x16):

%0 = memref.alloc() : memref<1024x1024xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>

Example 2 (suppose the tensor shape inferred by the compiler is 8x16):

%0 = memref.alloc(%h, %w) : memref<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], shape : [%h, %w], strides : [%w, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>

Example 3 (suppose the tensor shape inferred by the compiler is 8x16):

%0 = ... : ui64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
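As a rough illustration of how offsets and strides select the sub-view, the following Python sketch computes the linear element offset (relative to source) of every element in a 2D block. The helper nd_block_offsets is hypothetical, and it assumes element-granular strides as in Example 2, where the row stride is %w and the column stride is 1.

```python
from typing import List, Tuple


def nd_block_offsets(offsets: Tuple[int, int], strides: Tuple[int, int],
                     block_shape: Tuple[int, int]) -> List[List[int]]:
    """Linear element offsets (relative to `source`) covered by a 2D block.

    offsets     -- (row, col) start of the sub-view, e.g. %0[%c0, %c0]
    strides     -- (row_stride, col_stride) of the source, e.g. [%w, %c1]
    block_shape -- (rows, cols) of the TensorDesc, e.g. 8x16
    """
    (r0, c0), (sr, sc), (h, w) = offsets, strides, block_shape
    return [[(r0 + i) * sr + (c0 + j) * sc for j in range(w)] for i in range(h)]


# An 8x16 block starting at row 2, column 4 of a 1024-wide row-major source.
blk = nd_block_offsets((2, 4), (1024, 1), (8, 16))
print(blk[0][0], blk[7][15])  # 2052 9235
```

Each row of the result is a run of consecutive offsets, which reflects the statement that elements in the sub-view are contiguous in each dimension.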

Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
const_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
const_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands: 

Operand | Description
source | non-0-ranked memref of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values, or 64-bit unsigned, 32-bit unsigned, 64-bit signless, or 32-bit signless integer
offsets | variadic of index
shape | variadic of index
strides | variadic of index

Results: 

Result | Description
TensorDesc | TensorDesc describing regions of interested data.

xegpu.create_tdesc (xegpu::CreateDescOp) 

Create scattered tensor descriptors (TensorDesc).

Syntax:

operation ::= `xegpu.create_tdesc` $source `,` $offsets attr-dict `:`  type($source) `,` type($offsets) `->` qualified(type($TensorDesc))

“create_tdesc” is similar to “create_nd_tdesc” in that it creates a Tensor Descriptor (TensorDescType) for a memory region. While “create_nd_tdesc” creates contiguous sub-views, “create_tdesc” creates non-contiguous (scattered) sub-views, allowing each work-item in a subgroup to specify its own offset. It accepts the following parameters:

Arguments:

  • source: a 1D memref or pointer (i64, i32, ui64, ui32) representing the flattened memory object.
  • offsets: a vector containing the offset of each access point. Its size is fixed to the hardware-supported subgroup size, e.g., 16 on PVC, so each element in the vector corresponds to a work-item (SIMT lane) in the subgroup.

Results:

  • res: scattered tensor descriptor

The first dimension of the result TensorDesc corresponds to work-items, so it should match the size of offsets. It may also have a second dimension corresponding to chunk_size if the chunk size is larger than 1.

Example 1: It assumes a subgroup size of 4 and accesses a[0], a[16], a[32], a[64].

%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %a, %0 : memref<1024xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>

Example 2: It assumes a subgroup size of 4, with each work-item accessing 8 elements, for a total of 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71].

%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
      -> !xegpu.tensor_desc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>

Example 3: It is similar to Example 2, but the accesses of different work-items overlap. It accesses a[0:7], a[4:11], a[8:15], a[12:19].

%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
      -> !xegpu.tensor_desc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
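The access patterns in the three examples can be reproduced with a few lines of Python. This is a hypothetical sketch (scattered_accesses and scattered_tdesc_shape are illustrative names) that only computes which element indices each work-item touches and the resulting TensorDesc shape; it does not model memory spaces or hardware limits on chunk_size.

```python
from typing import List, Tuple


def scattered_accesses(offsets: List[int], chunk_size: int = 1) -> List[List[int]]:
    """Element indices touched by each work-item of a scattered TensorDesc."""
    return [list(range(off, off + chunk_size)) for off in offsets]


def scattered_tdesc_shape(offsets: List[int], chunk_size: int = 1) -> Tuple[int, ...]:
    # dim 0 = work-items; dim 1 = chunk_size, present only when chunk_size > 1
    n = len(offsets)
    return (n,) if chunk_size == 1 else (n, chunk_size)


# Example 2: offsets [0, 16, 32, 64] with chunk_size = 8
acc = scattered_accesses([0, 16, 32, 64], chunk_size=8)
print(scattered_tdesc_shape([0, 16, 32, 64], 8))  # (4, 8)
print(acc[3][0], acc[3][-1])                       # 64 71
```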

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
source | 1D memref of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values, or 64-bit unsigned, 32-bit unsigned, 64-bit signless, or 32-bit signless integer
offsets | fixed-length vector of index values

Results: 

Result | Description
TensorDesc | TensorDesc describing regions of interested data.

xegpu.dpas (xegpu::DpasOp) 

It performs a matrix multiply-accumulate (MMA) computation.

Syntax:

operation ::= `xegpu.dpas` $lhs `,` $rhs (`,` $acc^)? attr-dict `:` type($lhs)`,` type($rhs) (`,` type($acc)^)?  `->` type($result)

DPAS performs matrix multiplication of matrix A (size mxk) and matrix B (size kxn), and accumulates into matrix C (size mxn), producing a result of the same size, where m=8, n=16, and k=8*32/bit_width_of_elem_type. So for the fp16 data type (k=16), the matrices are A: vector<8x16xf16>, B: vector<16x16xf16>, and C/D: vector<8x16xf32>. Besides the matrix size requirements, DPAS also requires A and B to be loaded with the required data layout. In particular, the VNNI layout is required for the B operand; it is achieved by adding the packed attribute to the load_nd operation. Due to the VNNI transformation, the B operand can be represented as a 3D vector, with the last dimension representing the VNNI factor, which is computed as 32/bit_width_of_elem_type. Thus, B: vector<16x16xf16> can be represented as B: vector<8x16x2xf16>.

In SIMT code, each work-item in a subgroup holds a data fragment of A, B, C, and the result, which are represented as 1D vectors. Please refer to the OpenCL Intel extension cl_intel_subgroup_matrix_multiply_accumulate (https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html) for more details about the fragment distribution.

Note: on PVC, the hardware can perform a load with VNNI transformation when the data
      element type is 16-bit or lower precision, taking 2 or 4 elements from
      the first dimension and inserting them into the newly added innermost dimension.
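The shape rules and the VNNI packing described above can be sketched in Python as follows. The helpers are hypothetical and operate on plain nested lists; vnni_pack assumes the B matrix is given as k rows of n columns and moves groups of `factor` consecutive rows of each column into the new innermost dimension, matching the vector<16x16xf16> to vector<8x16x2xf16> example.

```python
from typing import List, Tuple

Shape = Tuple[int, int]


def dpas_shapes(elem_bits: int) -> Tuple[Shape, Shape, Shape]:
    """(A, B, C/D) shapes from m = 8, n = 16, k = 8 * 32 / elem_bits."""
    m, n = 8, 16
    k = 8 * 32 // elem_bits
    return (m, k), (k, n), (m, n)


def vnni_factor(elem_bits: int) -> int:
    return 32 // elem_bits


def vnni_pack(b: List[List[float]], factor: int) -> List[List[List[float]]]:
    """Pack a k x n matrix into (k/factor) x n x factor.

    packed[i][j][v] == b[i * factor + v][j]: `factor` consecutive rows of
    column j move into the new innermost dimension.
    """
    k, n = len(b), len(b[0])
    assert k % factor == 0
    return [[[b[i * factor + v][j] for v in range(factor)] for j in range(n)]
            for i in range(k // factor)]


a_shape, b_shape, c_shape = dpas_shapes(16)   # fp16
print(a_shape, b_shape, c_shape)              # (8, 16) (16, 16) (8, 16)
b = [[float(r * 16 + c) for c in range(16)] for r in range(16)]
packed = vnni_pack(b, vnni_factor(16))
print(len(packed), len(packed[0]), len(packed[0][0]))  # 8 16 2
```

For 8-bit elements the same formulas give k=32 and a VNNI factor of 4, which corresponds to the "2 or 4 elements" mentioned in the note.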

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
lhs | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2/3
rhs | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2/3
acc | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2

Results: 

Result | Description
result | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2

xegpu.fence (xegpu::FenceOp) 

It synchronizes memory accesses.

Syntax:

operation ::= `xegpu.fence` `memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` `` $fence_scope attr-dict

It synchronizes memory accesses between a write and the following read or write.

  1. memory_kind describes the memory kind: “global” means global memory, and “slm” means shared local memory.
  2. fence_scope describes the scope of the fence: “workgroup” means the scope is within each workgroup, and “gpu” means the scope is across workgroups within the GPU.

Attributes: 

Attribute | MLIR Type | Description
memory_kind | ::mlir::xegpu::MemorySpaceAttr | Describe the location of data described by a `TensorDesc`: Global device memory (`Global`) or Shared local memory (`SLM`).
fence_scope | ::mlir::xegpu::FenceScopeAttr | Describes the scope of fence. "workgroup" means that the scope is within each work group. "gpu" means the scope is across work groups within the gpu.

xegpu.init_nbarrier (xegpu::InitNbarrierOp) 

It assigns a named barrier to the current thread.

Syntax:

operation ::= `xegpu.init_nbarrier` $nbarrier_id `,` $participant_thread_num attr-dict `:`
              type($nbarrier_id) `,` type($participant_thread_num) `->` qualified(type($result))

InitNbarrierOp assigns the named barrier with the specified barrier ID (0-31) to the current thread. Multiple threads may bind to the same named barrier, and participant_thread_num specifies the total number of threads associated with the named barrier. It returns an object of NbarrierType representing the barrier.

Operands: 

Operand | Description
nbarrier_id | 8-bit signless integer
participant_thread_num | 8-bit signless integer

Results: 

Result | Description
result | !xegpu.nbarrier: a custom XeGPU type representing a barrier.

xegpu.load (xegpu::LoadGatherOp) 

Load a set of scattered data points from memory.

Syntax:

operation ::= `xegpu.load` $source
              (`[` $offsets^ `]`)? `,`
              $mask prop-dict
              attr-dict `:` type(operands) `->` type($value)

It (aka. load) loads data for each work-item. The output describes the data being loaded at the subgroup level, so its size is consistent with the number of work-items in a subgroup. When the chunk size is larger than 1, the output is a 2D vector, with dim-0 corresponding to work-items and dim-1 corresponding to the chunk size loaded by each work-item. The mask operand masks out memory accesses, so it is safe to pass out-of-bounds addresses/offsets as long as they are masked. The mask applies per SIMD lane.

In SIMT mode, the result is a 1D vector that represents the data to be loaded by each work-item. If its size is not 1, it must equal the chunk size.

Arguments:

  • source: the memory region to be loaded from, which can be either a tensor_desc or a 1D memref or pointer (ui64, ui32, i64, or i32). If source is a tensor_desc, the offsets come from the producing create_tdesc op. tensor_desc cannot be used in SIMT mode.
  • offsets: offsets from source; required if source is not a TensorDescType. offsets is a vector of index type whose length is either the subgroup size or 1 in SIMT mode. A scalar offset is also valid in SIMT mode.
  • mask: a vector of i1 type used to mask out memory accesses. Its size equals the subgroup size, or 1 in SIMT mode. A scalar mask is also valid in SIMT mode.
  • chunk_size: (optional) the number of contiguous elements loaded by each work-item.
  • l1_hint, l2_hint, l3_hint: are optional cache hints for each level of cache.

Results:

  • res: represents loaded data

Example 1:

  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
                           l2_hint = #xegpu.cache_hint<uncached>,
                           l3_hint = #xegpu.cache_hint<uncached>}>
        : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
          vector<16xi1> -> vector<16xf32>

Example 2:

  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
                           l2_hint = #xegpu.cache_hint<uncached>,
                           l3_hint = #xegpu.cache_hint<uncached>}>
        : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
          vector<16xi1> -> vector<16x8xf32>

Example 3: A variant accepts a memref as the base pointer together with offsets, instead of a scattered TensorDesc. It combines “create scattered TensorDesc” and “load with scattered TensorDesc”. The source operand can also be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc for the restrictions on the memref.

  %a = memref.alloc() : memref<1024xf32>
  %offsets = vector.step : vector<16xindex>
  %mask = vector.constant_mask [16]: vector<16xi1>
  %val = xegpu.load %a[%offsets], %mask <{l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<cached>,
                         l3_hint = #xegpu.cache_hint<cached>}>
    : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>

Example 4 (SIMT mode): SIMT mode only accepts the offsets variant. chunk_size can be inferred from result type. In this example, chunk_size is 8.

  %3 = xegpu.load %1[%2], %0 <{l1_hint = #xegpu.cache_hint<cached>,
                           l2_hint = #xegpu.cache_hint<uncached>,
                           l3_hint = #xegpu.cache_hint<uncached>}>
        : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
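As a rough model of the subgroup-level semantics, the Python sketch below gathers one element (or one contiguous chunk) per work-item and skips masked-off lanes entirely, which is why an out-of-bounds offset on a masked lane is harmless. gather_load is a hypothetical helper; returning None for masked-off lanes is an assumption, since the text above does not say what those lanes contain. In SIMT mode, each work-item would see only its own entry.

```python
from typing import List, Optional


def gather_load(memory: List[float], offsets: List[int], mask: List[bool],
                chunk_size: int = 1):
    """Subgroup-level model of a masked scattered load (xegpu.load).

    Returns one entry per work-item: a scalar when chunk_size == 1, else a
    list of chunk_size contiguous elements. Masked-off lanes never touch
    memory, so their offsets may be out of bounds; their entries are None
    here because this sketch does not model what those lanes hold.
    """
    rows: List[List[Optional[float]]] = []
    for off, active in zip(offsets, mask):
        if active:
            assert 0 <= off and off + chunk_size <= len(memory), "active lane out of bounds"
            rows.append(memory[off:off + chunk_size])
        else:
            rows.append([None] * chunk_size)
    return [r[0] for r in rows] if chunk_size == 1 else rows


mem = [float(i) for i in range(128)]
# Lane 3 carries an out-of-bounds offset, which is safe because it is masked.
print(gather_load(mem, [0, 16, 32, 1000], [True, True, True, False]))
# [0.0, 16.0, 32.0, None]
```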

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}

Attributes: 

Attribute | MLIR Type | Description
chunk_size | ::mlir::IntegerAttr | 64-bit signless integer attribute
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators

Operands: 

Operand | Description
source | TensorDesc describing regions of interested data, or 1D memref of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values, or 64-bit unsigned, 32-bit unsigned, 64-bit signless, or 32-bit signless integer
offsets | fixed-length vector of index values, or index
mask | fixed-length vector of 1-bit signless integer values, or 1-bit signless integer

Results: 

Result | Description
value | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values, or a scalar of one of the same element types

xegpu.load_matrix (xegpu::LoadMatrixOp) 

Syntax:

operation ::= `xegpu.load_matrix` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
              prop-dict attr-dict `` `:` type(operands) `->` type(results)

This operation loads a 2D block of data from shared local memory (SLM) as specified by the provided 2D mem_desc. Only 2D memory descriptors are supported; use the subview operation to obtain a compatible 2D mem_desc from a higher-rank descriptor if needed.

Arguments:

  • mem_desc: the memory descriptor identifying the SLM region.
  • offsets: the coordinates within the matrix to read from.
  • layout: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently accepts either LayoutAttr or SliceAttr.

Results:

  • res: the matrix elements loaded from SLM.

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}

Attributes: 

Attribute | MLIR Type | Description
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
layout | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance; common trait for all XeGPU layouts.

Operands: 

Operand | Description
mem_desc | MemDesc describing the data in SLM
offsets | variadic of index

Results: 

Result | Description
res | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values

xegpu.load_nd (xegpu::LoadNdOp) 

Loads an n-D block from memory (represented by TensorDesc) into registers (represented by vector).

Syntax:

operation ::= `xegpu.load_nd` $TensorDesc ``
              custom<OptionalDynamicIndexList>($offsets, $const_offsets)
              prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)

LoadNdOp essentially mimics the hardware block read instruction to read a block of data from memory into registers. It takes a set of optional cache hints for each level of cache: L1, L2, and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute is masked. VNNI transformation is a hardware feature of Intel GPUs that packs data during the load of the B operand of a matrix operation when the bit width of the data type is less than 32 bits, e.g., fp16. Transpose is another Intel hardware feature, which transposes the data while loading it when the data type is 32 or 64 bits wide, e.g., fp32 or fp64. This implies that VNNI and transpose cannot be used at the same time. These features are only available for 1D or 2D block tensor_desc.

In SIMT mode, result vector represents the data to be loaded by each work-item.

Example 1:

  %2 = xegpu.load_nd %1 <{transpose = array<i64: 1, 0>,
                    l1_hint = #xegpu.cache_hint<cached>,
                    l2_hint = #xegpu.cache_hint<uncached>,
                    l3_hint = #xegpu.cache_hint<streaming>}>
          : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>

Example 2 (SIMT mode):

  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
                    l2_hint = #xegpu.cache_hint<uncached>}>
    : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
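The following hypothetical Python sketch illustrates the two shape effects described above: a block read from a row-major 2D buffer, optionally transposed so that an 8x16 block comes back as 16x8, and the per-work-item element count in SIMT mode when the block is split evenly across a 16-lane subgroup. The even-split assumption is ours; the exact assignment of elements to lanes is determined by layouts that this sketch does not model, and cache hints are ignored.

```python
from typing import List


def load_block(mem: List[float], row_stride: int, r0: int, c0: int,
               h: int, w: int, transpose: bool = False) -> List[List[float]]:
    """Read an h x w block at (r0, c0) from a row-major buffer.

    With transpose=True the block comes back as w x h, matching the
    8x16xf32 -> 16x8xf32 shape change in Example 1.
    """
    block = [[mem[(r0 + i) * row_stride + (c0 + j)] for j in range(w)]
             for i in range(h)]
    if transpose:
        block = [list(col) for col in zip(*block)]
    return block


def simt_elems_per_lane(h: int, w: int, subgroup_size: int = 16) -> int:
    """Elements per work-item if an h x w block is split evenly across a subgroup."""
    assert (h * w) % subgroup_size == 0
    return h * w // subgroup_size


mem = [float(i) for i in range(64 * 64)]
t = load_block(mem, 64, 0, 0, 8, 16, transpose=True)
print(len(t), len(t[0]))            # 16 8
print(simt_elems_per_lane(8, 16))   # 8
```

The count of 8 agrees with the vector<8xf32> result in the SIMT example: 8x16 = 128 elements shared by 16 work-items.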

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}

Attributes: 

Attribute | MLIR Type | Description
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
packed | ::mlir::UnitAttr | unit attribute
transpose | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators

Operands: 

Operand | Description
TensorDesc | TensorDesc describing regions of interested data.
offsets | variadic of index

Results: 

Result | Description
value | fixed-length vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, 16/32/64-bit float, bfloat16, or tf32 values

xegpu.mem_desc_subview (xegpu::MemDescSubviewOp) 

Syntax:

operation ::= `xegpu.mem_desc_subview` $src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
              attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))

Creates a subview of a memory descriptor. The resulting memory descriptor can have a lower rank than the source; in this case, the result dimensions correspond to the higher-order dimensions of the source memory descriptor.

Arguments:

  • src : a memory descriptor.
  • offsets : the coordinates within the matrix the subview will be created from.

Results:

  • res : a memory descriptor with smaller size.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands: 

| Operand | Description |
| --- | --- |
| `src` | MemDesc describing the data in SLM |
| `offsets` | variadic of index |

Results: 

| Result | Description |
| --- | --- |
| `res` | MemDesc describing the data in SLM |

xegpu.nbarrier_arrive (xegpu::NbarrierArriveOp) 

It signals the arrival at the named barrier.

Syntax:

operation ::= `xegpu.nbarrier_arrive` $nbarrier attr-dict `:` qualified(type($nbarrier))

NbarrierArriveOp signals the hardware (or other threads) that the current thread has produced its data for the consumer threads. Once the hardware has been signalled by participant_thread_num threads for the named barrier, it notifies the threads waiting on that barrier so they can continue their work.

Operands: 

| Operand | Description |
| --- | --- |
| `nbarrier` | !xegpu.nbarrier, a custom XeGPU type representing a barrier |

xegpu.nbarrier_wait (xegpu::NbarrierWaitOp) 

It waits for a named barrier.

Syntax:

operation ::= `xegpu.nbarrier_wait` $nbarrier attr-dict `:` qualified(type($nbarrier))

NbarrierWaitOp tells the hardware which named barrier the current thread is waiting on, so that the thread is notified when the named barrier completes.
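The arrive/wait protocol of nbarrier_arrive and nbarrier_wait can be pictured with a small software model. This is a hypothetical sketch, not an XeGPU API. `NamedBarrier` stands in for one hardware named barrier, and `participant_count` plays the role of participant_thread_num. The model is single-use and does not try to reproduce hardware timing.

```python
import threading
from typing import List


class NamedBarrier:
    """Hypothetical software model of one named barrier (not an XeGPU API)."""

    def __init__(self, participant_count: int) -> None:
        self.participant_count = participant_count  # stands in for participant_thread_num
        self.arrived = 0
        self.cond = threading.Condition()

    def arrive(self) -> None:
        # Models nbarrier_arrive: record that this thread's data is ready.
        with self.cond:
            self.arrived += 1
            if self.arrived >= self.participant_count:
                self.cond.notify_all()

    def wait(self) -> None:
        # Models nbarrier_wait: block until enough arrivals have been recorded.
        with self.cond:
            self.cond.wait_for(lambda: self.arrived >= self.participant_count)


def run_producer_consumer(num_producers: int) -> List[int]:
    barrier = NamedBarrier(num_producers)
    shared: List[int] = [0] * num_producers
    result: List[int] = []

    def producer(i: int) -> None:
        shared[i] = i * i  # produce data, then signal arrival
        barrier.arrive()

    def consumer() -> None:
        barrier.wait()  # returns only after every producer has arrived
        result.extend(shared)

    threads = [threading.Thread(target=consumer)]
    threads += [threading.Thread(target=producer, args=(i,)) for i in range(num_producers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return result
```

In this model the consumer never observes partially produced data. That is the guarantee the two ops provide together.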

Operands: 

| Operand | Description |
| --- | --- |
| `nbarrier` | !xegpu.nbarrier, a custom XeGPU type representing a barrier |

xegpu.prefetch (xegpu::PrefetchOp) 

Prefetches a set of scattered data points to cache

Syntax:

operation ::= `xegpu.prefetch` $source
              (`[` $offsets^ `]`)?
              prop-dict
              attr-dict `:` type(operands)

It issues instructions to prefetch a set of scattered data points from memory to each level of the cache, according to the specified cache policies. Unlike prefetch_nd, which works on a non-scattered TensorDesc, it works on a scattered TensorDesc.

Arguments:

  • source: the memory region to prefetch from, which can be either a tensor_desc, a 1D memref, or a pointer (ui64, ui32, i64 or i32). If source is a tensor_desc, the offsets come from the producing create_tdesc op. tensor_desc cannot be used in SIMT mode.
  • offsets: offsets from source; required if source is not a TensorDescType. offsets is a vector of index type whose length is the subgroup size, or 1 in SIMT mode. A scalar offset is also valid in SIMT mode.
  • l1_hint, l2_hint, l3_hint: optional cache hints for each level of cache.
  • offset_align_byte: required if source is a pointer, and not allowed otherwise. It gives the alignment, in bytes, of each offset in offsets.

Example 1:

  xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<cached>,
                         l3_hint = #xegpu.cache_hint<cached>}
    : !xegpu.tensor_desc<16xf16>

Example 2: A variant that accepts a memref as the base pointer plus offsets, instead of a scattered TensorDesc. It combines “create scattered TensorDesc” and “prefetch with scattered TensorDesc”. The source operand can also be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc for the restrictions on the memref.

  %a = memref.alloc() : memref<1024xf32>
  %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
  xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<cached>,
                         l3_hint = #xegpu.cache_hint<cached>}
    : memref<1024xf32>, vector<4xindex>

Example 3 (SIMT mode): SIMT mode only accepts the offsets variant.

  xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<cached>,
                         l3_hint = #xegpu.cache_hint<cached>}
    : memref<256xf32>, vector<1xindex>

Example 4 (SIMT mode, raw pointer source): SIMT mode only accepts the offsets variant. Because the source is a pointer, offset_align_byte is required.

  xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<cached>,
                         l3_hint = #xegpu.cache_hint<cached>,
                         offset_align_byte = 2}
    : i64, vector<1xindex>

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| `l1_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l2_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l3_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `offset_align_byte` | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands: 

| Operand | Description |
| --- | --- |
| `source` | TensorDesc describing regions of interest in data, or 1D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values, or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
| `offsets` | fixed-length vector of index values or index |

xegpu.prefetch_nd (xegpu::PrefetchNdOp) 

Prefetches an n-D block to cache

Syntax:

operation ::= `xegpu.prefetch_nd` $TensorDesc ``
              custom<OptionalDynamicIndexList>($offsets, $const_offsets)
              prop-dict attr-dict `:` qualified(type($TensorDesc))

It issues an instruction to prefetch a block of data from contiguous memory regions to each level of the cache, according to the specified cache policies.

Example:

  xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                            l2_hint = #xegpu.cache_hint<cached>,
                            l3_hint = #xegpu.cache_hint<cached>}
    : !xegpu.tensor_desc<8x16xf16>

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| `const_offsets` | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| `l1_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l2_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l3_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |

Operands: 

| Operand | Description |
| --- | --- |
| `TensorDesc` | TensorDesc describing regions of interest in data |
| `offsets` | variadic of index |

xegpu.store (xegpu::StoreScatterOp) 

Store data to scattered memory locations.

Syntax:

operation ::= `xegpu.store` $value `,`
              $dest
              (`[` $offsets^ `]`)? `,`
              $mask
              prop-dict
              attr-dict `:`  type(operands)

The store op (StoreScatterOp) writes data to scattered memory locations. The value is typically a 1D vector. When the chunk size of the TensorDesc is larger than 1, it is a 2D vector instead. In that case, dim-0 of the value corresponds to the SIMD lanes and dim-1 corresponds to the chunk of contiguous elements stored per lane. The value therefore has the same shape as the TensorDesc; for example, a vector<16x8xf32> value is stored through a 16x8 TensorDesc with chunk_size = 8.

In SIMT mode, the value is a 1D vector representing the data stored by each work-item. If its size is not 1, it must equal the chunk size.

Arguments:

  • value: the data to be stored.
  • dest: the memory region to store to, which can be either a tensor_desc, a 1D memref, or a pointer (ui64, ui32, i64 or i32). If dest is a tensor_desc, the offsets come from the producing create_tdesc op. tensor_desc cannot be used in SIMT mode.
  • offsets: offsets from dest; required if dest is not a TensorDescType. offsets is a vector of index type whose length is the subgroup size, or 1 in SIMT mode. A scalar offset is also valid in SIMT mode.
  • mask: a vector of i1 used to mask out memory accesses. Its length is the subgroup size, or 1 in SIMT mode. A scalar mask is also valid in SIMT mode.
  • chunk_size: (optional) the number of contiguous elements stored per work-item.
  • l1_hint, l2_hint, l3_hint: optional cache hints for each level of cache.
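The interaction of value, offsets, mask and chunk_size can be summarized with a sequential reference model. The helper below is hypothetical and models semantics only, not hardware behaviour. It assumes that offsets are measured in elements of dest and that the value is laid out as [lanes, chunk_size], matching the 2D case described above.

```python
from typing import List, Sequence


def scatter_store(dest: List[float], value: Sequence[Sequence[float]],
                  offsets: Sequence[int], mask: Sequence[bool]) -> None:
    """Sequential reference model of a masked scatter store (hypothetical helper).

    value has shape [lanes, chunk_size]: row i is the chunk that lane i writes,
    starting at element offset offsets[i] of dest. Masked-off lanes write nothing.
    """
    if not (len(value) == len(offsets) == len(mask)):
        raise ValueError("value, offsets and mask must have one entry per lane")
    for lane, (offset, active) in enumerate(zip(offsets, mask)):
        if not active:
            continue  # this lane's memory access is masked out
        for j, element in enumerate(value[lane]):
            dest[offset + j] = element  # chunk elements are contiguous in dest


# Usage: 4 lanes, chunk_size = 2, lane 1 masked off.
memory = [0.0] * 16
scatter_store(memory, [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]],
              offsets=[0, 4, 8, 12], mask=[True, False, True, True])
```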

Example 1:

  xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
                           l2_hint = #xegpu.cache_hint<write_back>,
                           l3_hint = #xegpu.cache_hint<write_through>}>
        : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>

Example 2:

  xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
                           l2_hint = #xegpu.cache_hint<write_back>,
                           l3_hint = #xegpu.cache_hint<write_through>}>
        : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>>, vector<16xi1>

Example 3: A variant that accepts a memref as the base pointer plus offsets, instead of a scattered TensorDesc. It combines “create scattered TensorDesc” and “store with scattered TensorDesc”. The dest operand can also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on the memref.

  %a = memref.alloc() : memref<1024xf32>
  %val = arith.constant dense<0.0> : vector<16xf32>
  %offsets = vector.step : vector<16xindex>
  %mask = vector.constant_mask [16]: vector<16xi1>
  xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<cached>,
                         l3_hint = #xegpu.cache_hint<cached>}
    : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>

Example 4 (SIMT mode): SIMT mode only accepts the offsets variant. chunk_size can be inferred from value type. In this example, chunk_size is 8.

  xegpu.store %0, %1[%2], %3 <{l1_hint = #xegpu.cache_hint<uncached>,
                           l2_hint = #xegpu.cache_hint<write_back>,
                           l3_hint = #xegpu.cache_hint<write_through>}>
        : vector<8xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| `chunk_size` | ::mlir::IntegerAttr | 64-bit signless integer attribute |
| `l1_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l2_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l3_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |

Operands: 

| Operand | Description |
| --- | --- |
| `value` | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values, or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
| `dest` | TensorDesc describing regions of interest in data, or 1D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values, or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
| `offsets` | fixed-length vector of index values or index |
| `mask` | fixed-length vector of 1-bit signless integer values or 1-bit signless integer |

xegpu.store_matrix (xegpu::StoreMatrixOp) 

Syntax:

operation ::= `xegpu.store_matrix` $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
              prop-dict attr-dict `` `:` type(operands)

This operation stores a 2D data fragment into the shared local memory region specified by a 2D mem_desc. Only 2D memory descriptors are supported; use the subview operation to obtain a 2D mem_desc from a higher-rank descriptor if needed.

Arguments:

  • mem_desc: the memory descriptor specifying the SLM region.
  • offsets: the coordinates within the matrix where the data will be written.
  • data: the values to be stored in the matrix.
  • layout: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr.

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| `const_offsets` | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| `layout` | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance. Common trait for all XeGPU layouts. |

Operands: 

| Operand | Description |
| --- | --- |
| `data` | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
| `mem_desc` | MemDesc describing the data in SLM |
| `offsets` | variadic of index |

xegpu.store_nd (xegpu::StoreNdOp) 

Stores an n-D block register region back to memory, currently only supports 2D

Syntax:

operation ::= `xegpu.store_nd` $value `,`
              $TensorDesc ``
              custom<OptionalDynamicIndexList>($offsets, $const_offsets)
              prop-dict attr-dict `:`  type($value) `,` qualified(type($TensorDesc))

StoreNdOp essentially mimics the hardware block write instruction to write a block of data from registers into the memory region described by the TensorDesc. It takes a set of optional cache hints for each level of cache: L1, L2 and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute is masked. It is only available for 1D or 2D blocked tensor_desc.

In SIMT mode, the input vector represents the data to be stored by each work-item.

Example 1:

  xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                         l2_hint = #xegpu.cache_hint<write_back>,
                         l3_hint = #xegpu.cache_hint<write_through>}
                         : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>

Example 2 (SIMT mode):

  xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                         l2_hint = #xegpu.cache_hint<write_back>,
                         l3_hint = #xegpu.cache_hint<write_through>}
                         : vector<8xf16>, !xegpu.tensor_desc<8x16xf16>

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| `const_offsets` | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| `l1_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l2_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
| `l3_hint` | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |

Operands: 

| Operand | Description |
| --- | --- |
| `value` | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
| `TensorDesc` | TensorDesc describing regions of interest in data |
| `offsets` | variadic of index |

xegpu.update_nd_offset (xegpu::UpdateNdOffsetOp) 

It updates the offsets for the TensorDesc.

Syntax:

operation ::= `xegpu.update_nd_offset` $TensorDesc `,`
              custom<DynamicIndexList>($offsets, $const_offsets)
              attr-dict `:` qualified(type($result))

The op updates the offset of the given TensorDesc. The offsets are relative to the current position and are measured in number of elements. The result is a TensorDesc of the same type as the input.

Example:

  %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| `const_offsets` | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands: 

| Operand | Description |
| --- | --- |
| `TensorDesc` | TensorDesc describing regions of interest in data |
| `offsets` | variadic of index |

Results: 

| Result | Description |
| --- | --- |
| `result` | TensorDesc describing regions of interest in data |

xegpu.update_offset (xegpu::UpdateOffsetOp) 

It updates the offsets for the given tensor descriptor

Syntax:

operation ::= `xegpu.update_offset` $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)

It behaves similarly to update_nd_offset: it updates the offsets of a TensorDesc, and the offsets are relative to the current position, measured in number of elements. The difference is in what is shifted. update_nd_offset updates the start point of a 2D block, so its offsets contain two elements giving the shift in each dimension. update_offset updates the offset of each work-item, so its offsets contain one shift per work-item.

Example:
```mlir
  %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
  %2 = xegpu.update_offset %1, %off :
          !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size=2>>, vector<4xindex>
```
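The contrast between the two update ops can be sketched as plain arithmetic. This is a hypothetical model, not an XeGPU API:

* `update_nd_offset` shifts one 2D block origin by a (row, col) delta.
* `update_offset` shifts each work-item's offset independently.

The starting per-work-item offsets [0, 2, 4, 6] in the usage below are made up for illustration.

```python
from typing import List, Sequence, Tuple


def update_nd_offset(origin: Tuple[int, int], delta: Tuple[int, int]) -> Tuple[int, int]:
    # Block descriptor: a single 2D origin, shifted by one (row, col) delta in elements.
    return (origin[0] + delta[0], origin[1] + delta[1])


def update_offset(offsets: Sequence[int], deltas: Sequence[int]) -> List[int]:
    # Scattered descriptor: one offset per work-item, each shifted by its own delta.
    if len(offsets) != len(deltas):
        raise ValueError("one delta per work-item is required")
    return [o + d for o, d in zip(offsets, deltas)]


# Usage mirroring the examples: [0, 16] on a block, and +32 for each of 4 work-items.
block_origin = update_nd_offset((0, 0), (0, 16))
item_offsets = update_offset([0, 2, 4, 6], [32, 32, 32, 32])
```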

Operands: 

| Operand | Description |
| --- | --- |
| `TensorDesc` | TensorDesc describing regions of interest in data |
| `offsets` | fixed-length vector of index values |

Results: 

| Result | Description |
| --- | --- |
| `result` | TensorDesc describing regions of interest in data |

Attributes 

BlockTensorDescAttr 

A composite attribute for TensorDescType

Syntax:

#xegpu.block_tdesc_attr<
  MemorySpaceAttr,   # memory_space
  IntegerAttr,   # array_length
  BoolAttr   # boundary_check
>

BlockTensorDesc (or block_tdesc_attr) is a composite attribute defined for TensorDescType that describes the following properties of a TensorDesc:

  1. memory_space: where the data block described by the TensorDesc is located, either Global device memory or Shared local memory. It defaults to Global.

  2. array_length: how many horizontally consecutive blocks are loaded by a single hardware load instruction. For example, if the TensorDesc shape is 8x16 and array_length = 2, the loaded block shape is actually 8x32. Its default value is 1.

  3. boundary_check: whether the hardware performs out-of-bounds checking. The default value is true.
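The array_length arithmetic can be written out in a few lines. The helpers below are hypothetical, not part of XeGPU. They assume the consecutive blocks are adjacent along the column dimension, as the 8x16 to 8x32 example indicates.

```python
from typing import List, Tuple


def loaded_block_shape(tdesc_shape: Tuple[int, int], array_length: int = 1) -> Tuple[int, int]:
    # array_length blocks of the TensorDesc shape are placed side by side horizontally.
    rows, cols = tdesc_shape
    return (rows, cols * array_length)


def block_column_ranges(tdesc_shape: Tuple[int, int], array_length: int = 1) -> List[Tuple[int, int]]:
    # Half-open column range [start, end) covered by each block, relative to the TensorDesc origin.
    cols = tdesc_shape[1]
    return [(k * cols, (k + 1) * cols) for k in range(array_length)]
```

For the example above, `loaded_block_shape((8, 16), 2)` gives `(8, 32)`, with the two blocks covering columns [0, 16) and [16, 32).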

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `memory_space` | MemorySpaceAttr | Data memory location |
| `array_length` | IntegerAttr | Number of continuous blocks to load |
| `boundary_check` | BoolAttr | Checking the out of boundary access |

CachePolicyAttr 

Describe the cache settings for prefetch/load/store operators

Syntax:

#xegpu.cache_hint<
  ::mlir::xegpu::CachePolicy   # value
>

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `value` | ::mlir::xegpu::CachePolicy | an enum of type CachePolicy |

FenceScopeAttr 

Describes the scope of a fence. “workgroup” means that the scope is within each workgroup. “gpu” means that the scope spans workgroups within the GPU.

Syntax:

#xegpu.fence_scope<
  ::mlir::xegpu::FenceScope   # value
>

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `value` | ::mlir::xegpu::FenceScope | an enum of type FenceScope |

LayoutAttr 

Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor.

Syntax:

#xegpu.layout<
  DenseI32ArrayAttr,   # sg_layout
  DenseI32ArrayAttr,   # sg_data
  DenseI32ArrayAttr,   # inst_data
  DenseI32ArrayAttr,   # lane_layout
  DenseI32ArrayAttr,   # lane_data
  DenseI32ArrayAttr   # order
>

XeGPU operations use LayoutAttr to define how data is distributed across subgroups and work-items. This attribute is specified in a tensor descriptor when the descriptor is created. LayoutAttr includes the following parameters:

  • sg_layout: Specifies the total number of subgroups and their layout within a workgroup. It is mandatory for workgroup-level programming. Its presence implies workgroup-level code.
  • sg_data: Defines the data size accessed per subgroup. It is optionally used with sg_layout for workgroup-level programming. When it is left empty, the size accessed per subgroup can be derived from the tensor shape and sg_layout using the formula: sg_data[i] = tensor_shape[i] / sg_layout[i].
  • inst_data: Specifies the data size that is processed by an instruction. It is optionally used with lane_layout. When it is left empty, the data size per instruction is equivalent to the sg_data for workgroup-level programming or equivalent to tensor shape for subgroup-level programming.
  • lane_layout : Specifies the total number of work-items and their arrangement within a subgroup. It is mandatory for subgroup-level programming and optional for workgroup-level programming.
  • lane_data : Specifies the shape of the tensor fragment that each lane accesses. It defines a single, minimal distribution unit. Processing the entire tensor may require one or more distribution units per hardware instruction.
  • order: Specifies the dimension order used to linearize n-dimensional sg_layout and lane_layout to 1-dimensional layout. The first dimension in the order list is the fastest-changing dimension. If it is not present, the default value is [1, 0].
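The sg_data formula and the effect of order can be checked with two small helpers. This is a hypothetical Python sketch, not part of XeGPU:

* `derive_sg_data` applies sg_data[i] = tensor_shape[i] / sg_layout[i].
* `id_grid` assigns a linear id to each position of a 2D sg_layout or lane_layout, treating order[0] as the fastest-changing dimension, with the default order [1, 0].

```python
from typing import List, Sequence


def derive_sg_data(tensor_shape: Sequence[int], sg_layout: Sequence[int]) -> List[int]:
    # sg_data[i] = tensor_shape[i] / sg_layout[i] when sg_data is left empty.
    if any(t % s for t, s in zip(tensor_shape, sg_layout)):
        raise ValueError("tensor_shape must be divisible by sg_layout")
    return [t // s for t, s in zip(tensor_shape, sg_layout)]


def id_grid(layout: Sequence[int], order: Sequence[int] = (1, 0)) -> List[List[int]]:
    """Linear id at each 2D position of a layout; order[0] is the fastest-changing dim."""
    rows, cols = layout
    grid = [[0] * cols for _ in range(rows)]
    for r in range(rows):
        for c in range(cols):
            position = (r, c)
            linear, stride = 0, 1
            for dim in order:  # fastest-changing dimension first
                linear += position[dim] * stride
                stride *= layout[dim]
            grid[r][c] = linear
    return grid
```

With the default order, `id_grid([2, 8])` yields [[0, 1, ..., 7], [8, ..., 15]]. With order = [0, 1] it yields [[0, 2, ..., 14], [1, 3, ..., 15]]. These are the arrangements used in the examples of this section.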

Examples:

  1. Subgroup level layout:
#xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>

In this example, there are 16 work-items per subgroup, organized as [[0, 1, 2, ..., 7], [8, 9, ..., 15]]. The distribution unit is 1x1.

  2. Subgroup level layout with order:
#xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1], order = [0, 1]>

In this example, there are 16 work-items per subgroup, organized as [[0, 2, 4, ..., 14], [1, 3, 5, ..., 15]]. The distribution unit is 1x1.

  3. Subgroup level layout with inst_data:
#xegpu.layout<inst_data = [8, 16], lane_layout = [2, 8], lane_data = [2, 2]>

In this example, the original problem size is partitioned into smaller subproblems of dimensions [8, 16], which are then distributed among 16 work-items arranged as [[0, 1, 2, ..., 7], [8, 9, ..., 15]]. The 2x8 work-items, each taking 2x2 data, cover a 4x16 region per round. Within each 8x16 subproblem, every work-item is therefore assigned two 2x2 blocks in a round-robin manner.
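The round-robin assignment can be enumerated with a short sketch. `distribute` is a hypothetical helper, not part of XeGPU. It assumes row-major lane ids (the default order [1, 0]) and returns, for each lane, the top-left corners of the lane_data blocks it owns within one tile.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def distribute(tile: Sequence[int], lane_layout: Sequence[int],
               lane_data: Sequence[int]) -> Dict[int, List[Tuple[int, int]]]:
    """Round-robin distribution of a 2D tile to lanes (row-major lane ids)."""
    unit = [lane_layout[d] * lane_data[d] for d in range(2)]  # region covered per round
    if any(tile[d] % unit[d] for d in range(2)):
        raise ValueError("tile must be a multiple of lane_layout * lane_data")
    owned: Dict[int, List[Tuple[int, int]]] = defaultdict(list)
    for r0 in range(0, tile[0], unit[0]):          # rounds along dim 0
        for c0 in range(0, tile[1], unit[1]):      # rounds along dim 1
            for lr in range(lane_layout[0]):
                for lc in range(lane_layout[1]):
                    lane = lr * lane_layout[1] + lc
                    owned[lane].append((r0 + lr * lane_data[0], c0 + lc * lane_data[1]))
    return dict(owned)


blocks = distribute([8, 16], [2, 8], [2, 2])  # the inst_data example above
```

For the inst_data example, every lane owns two 2x2 blocks; lane 0, for instance, owns the blocks at (0, 0) and (4, 0). For a 16x16 subgroup tile with lane_layout = [2, 8] and lane_data = [1, 1], each lane receives 16 single elements.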

  4. Workgroup level layout:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [2, 8], lane_data = [1, 1]>

In this example, the layout represents a workgroup distribution. A workgroup consists of 8 subgroups arranged as [[0, 1, 2, 3], [4, 5, 6, 7]]. Each subgroup accesses a 16x16 block per instruction, which is further distributed to 16 work-items organized as [[0, 1, 2, ..., 7], [8, 9, ..., 15]].

  5. Workgroup level layout with order:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [2, 8], lane_data = [1, 1], order = [0, 1]>

In this example, the layout represents a workgroup distribution. A workgroup consists of 8 subgroups arranged as [[0, 2, 4, 6], [1, 3, 5, 7]]. Each subgroup accesses a 16x16 block per instruction, which is further distributed to 16 work-items organized as [[0, 2, 4, ..., 14], [1, 3, 5, ..., 15]].

  6. Workgroup level layout with inst_data:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], inst_data = [8, 16], lane_layout = [2, 8], lane_data = [1, 1]>

This example is similar to the previous ones, but the inst_data parameter divides sg_data into two instructions, each processing an 8x16 block. These blocks are further distributed across 16 work-items with a distribution unit of 1x1. Unlike the 2x2 distribution unit in example 3, which results in accessing contiguous 2x2 blocks, the 1x1 distribution unit may result in non-contiguous access.

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `sg_layout` | DenseI32ArrayAttr | |
| `sg_data` | DenseI32ArrayAttr | |
| `inst_data` | DenseI32ArrayAttr | |
| `lane_layout` | DenseI32ArrayAttr | |
| `lane_data` | DenseI32ArrayAttr | |
| `order` | DenseI32ArrayAttr | |

MemLayoutAttr 

Specifies memory layouts with named attributes.

This attribute stores a collection of named attributes that describe memory layout properties such as stride, block, etc.

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `attrs` | DictionaryAttr | |

MemorySpaceAttr 

Describe the location of data described by a TensorDesc: Global device memory (Global) or Shared local memory (SLM).

Syntax:

#xegpu.memory_space<
  ::mlir::xegpu::MemorySpace   # value
>

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `value` | ::mlir::xegpu::MemorySpace | an enum of type MemorySpace |

RangeAttr 

Specifies a half-open range

Syntax:

#xegpu.range<
  IntegerAttr,   # start
  IntegerAttr   # end
>

RangeAttr is an attribute that defines a half-open range [start, end): the start value is included and the end value is excluded. One use of this attribute is to specify a subgroup ID range, which can be attached to an scf.if op as follows:

scf.if %cond {
  // some operations
} {sg_id_range = #xegpu.range<[2, 4]>}

In this case, the scf.if op will only be executed for subgroup IDs 2 and 3.

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `start` | IntegerAttr | |
| `end` | IntegerAttr | |

ScatterTensorDescAttr 

A composite attribute for TensorDescType

Syntax:

#xegpu.scatter_tdesc_attr<
  MemorySpaceAttr,   # memory_space
  IntegerAttr   # chunk_size
>

ScatterTensorDesc is a composite attribute defined for TensorDescType that describes the following properties of a TensorDesc:

  1. memory_space: where the data block described by the TensorDesc is located, either Global device memory or Shared local memory. It defaults to Global.

  2. chunk_size: Specifies the number of contiguous elements accessed per offset. The default value is 1.

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `memory_space` | MemorySpaceAttr | Data memory location |
| `chunk_size` | IntegerAttr | Number of contiguous elements |

SliceAttr 

Describes the data distribution and sharing among subgroups or work-items.

Syntax:

#xegpu.slice<
  xegpu::DistributeLayoutAttr,   # parent
  DenseI64ArrayAttr   # dims
>

Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items. However, whereas LayoutAttr requires the data to have the same rank as the attribute, SliceAttr permits the data to have a lower rank. In this case, compute units in the specified dimensions (given by $dims) share the data, provided that the remaining ranks match the data rank. SliceAttr is commonly used by operations such as vector.multi_reduction and vector.broadcast.

Example:

#l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
#r = #xegpu.slice<#l, dims = [0]>

%exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
%red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
%bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>

In this example, %red is conceptually divided into 4 vectors of type vector<32xf32>, each assigned to a group of subgroups. Each group consists of 8 subgroups from the same column of sg_layout, sharing a single reduction result of type vector<32xf32>.
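Which part of %red each subgroup holds can be computed directly from #l. The sketch is hypothetical and uses the sg_layout and sg_data values of the example. It assumes the default order [1, 0], so a subgroup id maps to (row, col) = divmod(id, 4). Reducing dim 0 drops the row coordinate, so only the column selects the slice.

```python
from typing import List, Tuple

SG_LAYOUT = (8, 4)  # from #l in the example
SG_DATA = (32, 32)


def reduced_slice(sg_id: int) -> Tuple[int, int]:
    """Half-open range of %red (vector<128xf32>) held by subgroup sg_id."""
    _row, col = divmod(sg_id, SG_LAYOUT[1])  # default order [1, 0]
    start = col * SG_DATA[1]                 # the reduced dim 0 no longer selects data
    return (start, start + SG_DATA[1])


def sharers(slice_range: Tuple[int, int]) -> List[int]:
    # Subgroups that hold, and therefore share, the same reduced slice.
    return [s for s in range(SG_LAYOUT[0] * SG_LAYOUT[1]) if reduced_slice(s) == slice_range]
```

`sharers((0, 32))` returns the 8 subgroups of column 0, matching the grouping described above.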

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `parent` | xegpu::DistributeLayoutAttr | |
| `dims` | DenseI64ArrayAttr | |

Types 

MemDescType 

MemDesc describing the data in SLM

MemDesc represents a block of data stored in shared local memory. By default, unless a layout attribute is provided, the data is stored contiguously in row-major order within the region.

Examples:

// A multi-dimensional array stored in column-major order.
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128]>>

// A multi-dimensional array stored in a blocked layout. Elements within the same block
// are stored contiguously in memory. Blocks are stored in row-major order.
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<block = [8, 8]>>

// A multi-dimensional array stored in column-major order with blocked layout.
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128], block = [8, 8]>>
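The first two layouts can be made concrete by computing the linear element offset of an index (i, j). This is a hypothetical sketch, not the lowering XeGPU actually uses. The stride case follows directly from the example: stride = [1, 128] is column-major. For the blocked case, blocks are placed in row-major order as stated above. Storing the elements inside each block in row-major order is an added assumption, since the text says only that they are contiguous.

```python
from typing import Sequence


def strided_offset(idx: Sequence[int], stride: Sequence[int]) -> int:
    # Linear element offset for an explicit stride, e.g. stride = [1, 128] (column-major).
    return sum(i * s for i, s in zip(idx, stride))


def blocked_offset(idx: Sequence[int], shape: Sequence[int], block: Sequence[int]) -> int:
    """Element offset for a 2D blocked layout: blocks in row-major order and
    (assumption) elements inside a block also in row-major order."""
    (i, j), (_rows, cols), (br, bc) = idx, shape, block
    blocks_per_row = cols // bc
    block_id = (i // br) * blocks_per_row + (j // bc)
    within = (i % br) * bc + (j % bc)
    return block_id * (br * bc) + within
```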

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| `shape` | ::llvm::ArrayRef<int64_t> | |
| `elementType` | mlir::Type | |
| `mem_layout` | MemLayoutAttr | |

NbarrierType 

!xegpu.nbarrier a custom XeGPU type representing a barrier.

Syntax: !xegpu.nbarrier

TensorDescType 

TensorDesc describing regions of interest in data.

TensorDesc is a type designed to describe regions of interest in data, as well as some features unique to Intel hardware. Unlike the built-in tensor type in MLIR, it essentially contains only metadata and does not hold the data itself. It is primarily designed to support 2D block load/store and DPAS (matrix multiplication instruction) on Intel GPUs. It encodes the following information:

  • shape: the sizes/shape of the data block of interest. For example, 8x16 means 8 rows, each containing 16 contiguous data elements. Whether the rows themselves are contiguous depends on the encoding attribute: with a BlockTensorDescAttr they are contiguous, and with a ScatterTensorDescAttr they need not be. If no encoding is set, a default BlockTensorDescAttr is assumed.

  • element_type: the data type of the data element, e.g., f16, f32.

Similar to the built-in tensor, it also provides optional attributes for encoding additional information via either BlockTensorDescAttr or ScatterTensorDescAttr. It also supports workgroup-, subgroup-, and work-item (or SIMT)-level programming via the Layout attribute. Please check their definitions for details.

Syntax:

TensorDesc-type ::= `tensor_desc` `<` dim-list element-type (attr-list)? `>`
element-type ::= float-type | integer-type | index-type
dim-list ::= (static-dim-list `x`)?
static-dim-list ::= decimal-literal `x` decimal-literal
attr-list ::= (, encoding-attr)? (, layout-attr)?
encoding-attr ::= (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
layout-attr ::= (, layout `<`sg_layout = value, sg_data = value, inst_data = value, lane_layout = value, lane_data = value, order = value`>`)?

Examples:

// A block TensorDesc with 8x16 i32 elements
xegpu.tensor_desc<8x16xi32>

// A block TensorDesc with 8x16 f32 elements
xegpu.tensor_desc<8x16xf32>

// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>

// A 1D TensorDesc with a layout for subgroup level programming, each lane access two continuous elements
xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [2]>>

// A 1D TensorDesc with a layout for subgroup level programming, each lane access two elements with stride = 16
xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>

// A TensorDesc with a layout for subgroup level programming
xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>

// A TensorDesc with a layout for workgroup level programming
xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>

// A TensorDesc with a layout for workgroup level programming without lane_layout and lane_data
xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16]>>
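To make the two 1D lane layouts and the workgroup layout in these examples concrete, here is a simplified Python model. It assumes that lane_data-sized chunks are dealt round-robin across the lanes of lane_layout until the shape is covered, and that subgroups tile the workgroup shape in row-major order with sg_layout * sg_data exactly equal to the shape. `lane_elements_1d` and `subgroup_tile_offsets` are hypothetical helpers, not XeGPU APIs, and the sketch ignores inst_data and order.

```python
# Simplified, hypothetical model of xegpu.layout distribution; not the dialect's
# actual implementation.

def lane_elements_1d(shape, lane_layout, lane_data):
    """Return {lane_id: [element indices]} for a 1D tensor_desc."""
    per_pass = lane_layout * lane_data  # elements covered by one pass over all lanes
    if shape % per_pass != 0:
        raise ValueError("shape must be a multiple of lane_layout * lane_data")
    owned = {lane: [] for lane in range(lane_layout)}
    for base in range(0, shape, per_pass):
        for lane in range(lane_layout):
            start = base + lane * lane_data
            owned[lane].extend(range(start, start + lane_data))
    return owned


def subgroup_tile_offsets(shape, sg_layout, sg_data):
    """Return {(sg_row, sg_col): (row_offset, col_offset)} for a 2D workgroup layout."""
    rows, cols = sg_layout
    if (rows * sg_data[0], cols * sg_data[1]) != tuple(shape):
        raise ValueError("this sketch only handles sg_layout * sg_data == shape")
    return {(r, c): (r * sg_data[0], c * sg_data[1])
            for r in range(rows) for c in range(cols)}


print(lane_elements_1d(32, 16, 2)[1])  # lane 1 -> [2, 3]   (two contiguous elements)
print(lane_elements_1d(32, 16, 1)[1])  # lane 1 -> [1, 17]  (stride of 16)
print(subgroup_tile_offsets((32, 64), (2, 4), (16, 16))[(1, 3)])  # (16, 48)
```

Under these assumptions, the 32x64 workgroup example is split into 2 * 4 = 8 subgroups, each owning one 16x16 tile.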

Parameters: 

| Parameter | C++ type | Description |
| --- | --- | --- |
| shape | ::llvm::ArrayRef<int64_t> | |
| elementType | mlir::Type | |
| encoding | mlir::Attribute | |
| layout | mlir::Attribute | |

Enums 

CmpFPredicate 

Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15

Cases: 

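| Symbol | Value | String |
| --- | --- | --- |
| AlwaysFalse | 0 | false |
| OEQ | 1 | oeq |
| OGT | 2 | ogt |
| OGE | 3 | oge |
| OLT | 4 | olt |
| OLE | 5 | ole |
| ONE | 6 | one |
| ORD | 7 | ord |
| UEQ | 8 | ueq |
| UGT | 9 | ugt |
| UGE | 10 | uge |
| ULT | 11 | ult |
| ULE | 12 | ule |
| UNE | 13 | une |
| UNO | 14 | uno |
| AlwaysTrue | 15 | true |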
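The ordered (o*) and unordered (u*) predicates differ only when an operand is NaN: ordered predicates are false if either operand is NaN, unordered ones are true. The sketch below models this in Python, assuming the standard `arith.cmpf` semantics that these predicate names share; `cmpf` is a hypothetical helper, not an MLIR API.

```python
import math
import operator

# Relation shared by each ordered/unordered pair, e.g. "oeq" and "ueq".
_RELATION = {"eq": operator.eq, "gt": operator.gt, "ge": operator.ge,
             "lt": operator.lt, "le": operator.le, "ne": operator.ne}


def cmpf(predicate, a, b):
    """Evaluate a CmpFPredicate string (e.g. "olt", "une") on two floats."""
    unordered = math.isnan(a) or math.isnan(b)
    if predicate == "false":
        return False
    if predicate == "true":
        return True
    if predicate == "ord":
        return not unordered
    if predicate == "uno":
        return unordered
    kind, relation = predicate[0], predicate[1:]
    if unordered:
        # With a NaN operand, ordered predicates fail and unordered ones succeed.
        return kind == "u"
    return _RELATION[relation](a, b)


print(cmpf("oeq", math.nan, 1.0), cmpf("ueq", math.nan, 1.0))  # False True
```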

CmpIPredicate 

Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9

Cases: 

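| Symbol | Value | String |
| --- | --- | --- |
| eq | 0 | eq |
| ne | 1 | ne |
| slt | 2 | slt |
| sle | 3 | sle |
| sgt | 4 | sgt |
| sge | 5 | sge |
| ult | 6 | ult |
| ule | 7 | ule |
| ugt | 8 | ugt |
| uge | 9 | uge |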
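The s* and u* predicates interpret the same bit pattern as a two's-complement or an unsigned value, so they can disagree. A minimal Python sketch, assuming the usual `arith.cmpi` meaning of these predicates; `cmpi` and `to_signed` are hypothetical helpers.

```python
def to_signed(bits, width):
    """Reinterpret a width-bit unsigned pattern as a two's-complement value."""
    return bits - (1 << width) if (bits >> (width - 1)) & 1 else bits


def cmpi(predicate, a, b, width):
    """Compare two width-bit integers given as bit patterns."""
    mask = (1 << width) - 1
    ua, ub = a & mask, b & mask                          # unsigned view
    sa, sb = to_signed(ua, width), to_signed(ub, width)  # signed view
    results = {
        "eq": ua == ub, "ne": ua != ub,
        "slt": sa < sb, "sle": sa <= sb, "sgt": sa > sb, "sge": sa >= sb,
        "ult": ua < ub, "ule": ua <= ub, "ugt": ua > ub, "uge": ua >= ub,
    }
    return results[predicate]


# 0xFF is -1 as a signed i8 but 255 as an unsigned i8.
print(cmpi("slt", 0xFF, 1, 8), cmpi("ult", 0xFF, 1, 8))  # True False
```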

IntegerOverflowFlags 

Integer overflow arith flags

Cases: 

| Symbol | Value | String |
| --- | --- | --- |
| none | 0 | none |
| nsw | 1 | nsw |
| nuw | 2 | nuw |
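nsw (no signed wrap) and nuw (no unsigned wrap) assert that an operation does not overflow in the signed or unsigned sense; in LLVM and arith semantics, an operation carrying a flag that it violates produces poison. The minimal sketch below checks, for addition only, which flags would be valid for given operands at a given bit width; `flags_that_hold` is a hypothetical helper.

```python
def flags_that_hold(a, b, width):
    """Return the subset of {"nsw", "nuw"} that is valid for a + b at this width."""
    mask = (1 << width) - 1
    ua, ub = a & mask, b & mask
    smin, smax = -(1 << (width - 1)), (1 << (width - 1)) - 1
    sa = ua - (1 << width) if ua > smax else ua  # signed view of a
    sb = ub - (1 << width) if ub > smax else ub  # signed view of b
    holds = set()
    if smin <= sa + sb <= smax:  # exact signed sum fits: no signed wrap
        holds.add("nsw")
    if ua + ub <= mask:          # exact unsigned sum fits: no unsigned wrap
        holds.add("nuw")
    return holds


print(flags_that_hold(127, 1, 8))  # {'nuw'}: 127 + 1 wraps as a signed i8
```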

RoundingMode 

Floating point rounding mode

Cases: 

| Symbol | Value | String |
| --- | --- | --- |
| to_nearest_even | 0 | to_nearest_even |
| downward | 1 | downward |
| upward | 2 | upward |
| toward_zero | 3 | toward_zero |
| to_nearest_away | 4 | to_nearest_away |
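These five modes correspond to the IEEE 754 rounding-direction attributes. As an illustration only, the sketch below applies each mode when rounding a decimal value to an integer using Python's `decimal` module; real floating-point operations apply the mode when rounding to the destination format's precision, not to integers. `round_to_integer` is a hypothetical helper.

```python
from decimal import (Decimal, ROUND_CEILING, ROUND_DOWN, ROUND_FLOOR,
                     ROUND_HALF_EVEN, ROUND_HALF_UP)

# Map each RoundingMode string to the equivalent decimal rounding constant.
ROUNDING = {
    "to_nearest_even": ROUND_HALF_EVEN,  # ties go to the even neighbor
    "downward": ROUND_FLOOR,             # toward -infinity
    "upward": ROUND_CEILING,             # toward +infinity
    "toward_zero": ROUND_DOWN,           # truncate
    "to_nearest_away": ROUND_HALF_UP,    # ties go away from zero
}


def round_to_integer(value, mode):
    """Round a decimal string such as "-2.5" to an integer under the given mode."""
    return int(Decimal(value).quantize(Decimal(1), rounding=ROUNDING[mode]))


for mode in ROUNDING:
    print(mode, round_to_integer("-2.5", mode))
```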

AtomicRMWKind 

Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15

Cases: 

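| Symbol | Value | String |
| --- | --- | --- |
| addf | 0 | addf |
| addi | 1 | addi |
| andi | 2 | andi |
| assign | 3 | assign |
| maximumf | 4 | maximumf |
| maxnumf | 5 | maxnumf |
| maxs | 6 | maxs |
| maxu | 7 | maxu |
| minimumf | 8 | minimumf |
| minnumf | 9 | minnumf |
| mins | 10 | mins |
| minu | 11 | minu |
| mulf | 12 | mulf |
| muli | 13 | muli |
| ori | 14 | ori |
| xori | 15 | xori |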
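Each kind names the function that combines the value already in memory with the operand. The sketch below models only that combining step, sequentially and without any atomicity. It assumes 32-bit integers passed as unsigned bit patterns and Python (double-precision) floats, and it follows the arith dialect's definitions for the NaN behavior: maximumf/minimumf propagate NaN and order -0.0 below +0.0, while maxnumf/minnumf return the non-NaN operand. `apply_rmw` and `COMBINE` are hypothetical names.

```python
import math
import operator

WIDTH = 32               # assumed integer bit width
MASK = (1 << WIDTH) - 1  # integers are modeled as unsigned bit patterns


def _signed(x):
    """Reinterpret a WIDTH-bit pattern as a two's-complement integer."""
    x &= MASK
    return x - (1 << WIDTH) if x >> (WIDTH - 1) else x


def _is_pos_zero(x):
    return x == 0.0 and math.copysign(1.0, x) > 0


def _maximumf(a, b):
    # NaN-propagating; -0.0 is treated as less than +0.0.
    if math.isnan(a) or math.isnan(b):
        return math.nan
    if a == b == 0.0:
        return a if _is_pos_zero(a) else b
    return max(a, b)


def _minimumf(a, b):
    # NaN-propagating; -0.0 is treated as less than +0.0.
    if math.isnan(a) or math.isnan(b):
        return math.nan
    if a == b == 0.0:
        return b if _is_pos_zero(a) else a
    return min(a, b)


def _maxnumf(a, b):
    # A NaN operand is ignored in favor of the other operand.
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)


def _minnumf(a, b):
    # A NaN operand is ignored in favor of the other operand.
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return min(a, b)


COMBINE = {
    "addf": operator.add,
    "addi": lambda a, b: (a + b) & MASK,  # wraps at WIDTH bits
    "andi": operator.and_,
    "assign": lambda a, b: b,             # simply stores the operand
    "maximumf": _maximumf,
    "maxnumf": _maxnumf,
    "maxs": lambda a, b: a if _signed(a) >= _signed(b) else b,
    "maxu": lambda a, b: max(a & MASK, b & MASK),
    "minimumf": _minimumf,
    "minnumf": _minnumf,
    "mins": lambda a, b: a if _signed(a) <= _signed(b) else b,
    "minu": lambda a, b: min(a & MASK, b & MASK),
    "mulf": operator.mul,
    "muli": lambda a, b: (a * b) & MASK,  # wraps at WIDTH bits
    "ori": operator.or_,
    "xori": operator.xor,
}


def apply_rmw(kind, old, value):
    """Return the value that would be stored back for the given AtomicRMWKind."""
    return COMBINE[kind](old, value)


print(apply_rmw("maxs", 0xFFFFFFFF, 1), apply_rmw("maxu", 0xFFFFFFFF, 1))  # 1 4294967295
```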

FastMathFlags 

Floating point fast math flags

Cases: 

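| Symbol | Value | String |
| --- | --- | --- |
| none | 0 | none |
| reassoc | 1 | reassoc |
| nnan | 2 | nnan |
| ninf | 4 | ninf |
| nsz | 8 | nsz |
| arcp | 16 | arcp |
| contract | 32 | contract |
| afn | 64 | afn |
| fast | 127 | fast |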
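The values are bit flags that can be combined, and `fast` (127) is the union of all seven individual flags (1 + 2 + 4 + 8 + 16 + 32 + 64). A small sketch using Python's `enum.IntFlag`; the class mirrors the table above and the `combine` helper is hypothetical, not an MLIR API.

```python
from enum import IntFlag


class FastMathFlags(IntFlag):
    """Mirror of the FastMathFlags table; each flag occupies one bit."""
    none = 0
    reassoc = 1
    nnan = 2
    ninf = 4
    nsz = 8
    arcp = 16
    contract = 32
    afn = 64
    fast = 127  # all of the above


def combine(*names):
    """OR together the flags named by their string spellings."""
    flags = FastMathFlags.none
    for name in names:
        flags |= FastMathFlags[name]
    return flags


print(int(combine("nnan", "ninf")))  # 6
```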

CachePolicy 

Cache policy

Cases: 

| Symbol | Value | String |
| --- | --- | --- |
| CACHED | 0 | cached |
| UNCACHED | 1 | uncached |
| STREAMING | 2 | streaming |
| READ_INVALIDATE | 3 | read_invalidate |
| WRITE_BACK | 4 | write_back |
| WRITE_THROUGH | 5 | write_through |

FenceScope 

The scope of a fence operation.

Cases: 

| Symbol | Value | String |
| --- | --- | --- |
| Workgroup | 0 | workgroup |
| GPU | 1 | gpu |

MemorySpace 

The address space of the memory that the tensor descriptor is created for.

Cases: 

| Symbol | Value | String |
| --- | --- | --- |
| Global | 0 | global |
| SLM | 3 | slm |