'xegpu' Dialect
The XeGPU dialect that models Intel GPU’s ISA
The XeGPU dialect closely models a subset of the Xe GPU’s ISA, providing an abstraction to support high-performance GEMM code generation. It serves as a bridge dialect in the MLIR gradual lowering process, working with MLIR memref and vector types, and complements the Arith, Math, Vector, and Memref dialects. XeGPU operations are introduced for special Xe instructions not modeled by the LLVM/SPIR-V dialect, such as DPAS and 2D block load and store.
It supports a tile-based programming model, decomposing the GEMM kernel into large predefined tile sizes at the subgroup and workgroup levels. XeGPU allows the high-level GEMM algorithm to be easily expressed. Underneath, it uses target-specific recipes and hardware features to achieve optimal performance on specific hardware. By decomposing GEMM at submatrix granularity and mapping it to registers, it naturally supports optimizations like fusing with neighboring operations.
Operations ¶
xegpu.alloc_nbarrier
(xegpu::AllocNbarrierOp) ¶
It allocates a set of named barriers.
Syntax:
operation ::= `xegpu.alloc_nbarrier` $nbarrier_num attr-dict
AllocNbarrier creates a set of named barriers as specified by `nbarrier_num`. Named barriers are workgroup-level resources and are shared by all threads in the workgroup. For example, there are up to 32 barriers (range 0-31) for each XeCore on PVC. A typical use case is that a workgroup is partitioned into N subgroups of threads (N <= 32), and each subgroup coordinates its work with a separate barrier, with ids ranging from 0 to N-1.
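The barrier-per-subgroup use case can be sketched in a few lines of Python. This is only an illustration of the id arithmetic, not XeGPU code: `assign_barrier_ids` and `MAX_NAMED_BARRIERS` are hypothetical names, and the limit of 32 is the PVC figure quoted above.

```python
from typing import List

# Per-XeCore limit on PVC, as stated in the description above.
MAX_NAMED_BARRIERS = 32


def assign_barrier_ids(num_subgroups: int) -> List[int]:
    """Return one barrier id per subgroup.

    N subgroups need N distinct barriers, so the ids run from 0 to N - 1.
    """
    if not 1 <= num_subgroups <= MAX_NAMED_BARRIERS:
        raise ValueError(
            f"need between 1 and {MAX_NAMED_BARRIERS} subgroups, got {num_subgroups}"
        )
    return list(range(num_subgroups))


if __name__ == "__main__":
    # A workgroup split into 4 subgroups uses barrier ids 0, 1, 2 and 3.
    print(assign_barrier_ids(4))
```

In this model the value passed as `nbarrier_num` would be the length of the returned list.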
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
nbarrier_num | ::mlir::IntegerAttr | 64-bit signless integer attribute |
xegpu.atomic_rmw
(xegpu::AtomicRMWOp) ¶
Atomic read-modify-write operation on the TensorDesc.
Syntax:
operation ::= `xegpu.atomic_rmw` $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)
The `xegpu.atomic_rmw` operation provides a way to perform a read-modify-write operation on the region described by the `TensorDesc`, free from data races. The `kind` enumeration specifies the modification to be performed. The `mask` operand has the same shape as the `TensorDesc` and is used to enable or disable specific data points of the `TensorDesc`. The `value` operand represents the new value to be applied during the modification.
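The role of the mask can be illustrated with a small sequential Python model. It is a sketch under assumptions: `masked_rmw` and the `KINDS` table are hypothetical, only three modification kinds are shown, and the model says nothing about atomicity or about what the op returns; it only shows which data points are modified.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical subset of modification kinds; the keys are illustrative labels.
KINDS: Dict[str, Callable[[int, int], int]] = {
    "addi": lambda old, new: old + new,
    "maxs": lambda old, new: max(old, new),
    "assign": lambda old, new: new,
}


def masked_rmw(
    kind: str, memory: Sequence[int], mask: Sequence[bool], value: Sequence[int]
) -> List[int]:
    """Return a copy of `memory` with `kind` applied where the mask is set.

    Data points whose mask bit is False keep their original contents.
    """
    if not (len(memory) == len(mask) == len(value)):
        raise ValueError("memory, mask and value must have the same shape")
    op = KINDS[kind]
    return [op(m, v) if enabled else m for m, enabled, v in zip(memory, mask, value)]


if __name__ == "__main__":
    # Only the first and third data points are enabled.
    print(masked_rmw("addi", [1, 2, 3, 4], [True, False, True, False], [10, 10, 10, 10]))
```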
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, MemoryEffectOpInterface (MemoryEffectOpInterface), NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource, MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}, MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
kind | ::mlir::arith::AtomicRMWKindAttr | allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 |
Operands: ¶
Operand | Description |
---|---|
tensorDesc | TensorDesc describing regions of interested data. |
mask | fixed-length vector of 1-bit signless integer values |
value | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
Results: ¶
Result | Description |
---|---|
result | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
xegpu.convert_layout
(xegpu::ConvertLayoutOp) ¶
Convert the layout of the input operand
Syntax:
operation ::= `xegpu.convert_layout` $source prop-dict attr-dict `:` type($source)
`convert_layout` redistributes data across subgroups and/or work-items from the `input_layout` to the `target_layout`. Both `input_layout` and `target_layout` must correspond to the same programming scope, such as workgroup-level (wg) or subgroup-level (sg) code. This operation is not valid once the IR is lowered to the work-item (WI) level, because that is the end result of all distributions.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
input_layout | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance |
target_layout | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance |
Operands: ¶
Operand | Description |
---|---|
source | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4/5/6 |
Results: ¶
Result | Description |
---|---|
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4/5/6 |
xegpu.create_mem_desc
(xegpu::CreateMemDescOp) ¶
Create a memory descriptor.
Syntax:
operation ::= `xegpu.create_mem_desc` $source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))
Creates a memory descriptor from a shared local memory (SLM) buffer and an XeGPU-specific memory layout. The resulting memory descriptor must have the same size as the underlying shared local memory.
Arguments:
`source`: a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
Results:
`mem_desc`: the memory descriptor.
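The size requirement can be checked with simple arithmetic. The Python sketch below assumes that "same size" means the same number of bytes, with the descriptor size computed from its shape and element bit width; `mem_desc_matches_slm` is a hypothetical helper, not part of the dialect.

```python
from math import prod
from typing import Sequence


def mem_desc_matches_slm(slm_bytes: int, desc_shape: Sequence[int], elem_bits: int) -> bool:
    """Check that a descriptor of `desc_shape` covers exactly `slm_bytes` bytes.

    The comparison is done in bits so that sub-byte element types do not
    require rounding.
    """
    return prod(desc_shape) * elem_bits == slm_bytes * 8


if __name__ == "__main__":
    # A 2048-byte i8 buffer matches a 32x32 descriptor of 16-bit elements.
    print(mem_desc_matches_slm(2048, (32, 32), 16))
    # It does not match a 32x64 descriptor of 16-bit elements (4096 bytes).
    print(mem_desc_matches_slm(2048, (32, 64), 16))
```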
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands: ¶
Operand | Description |
---|---|
source | statically shaped memref of 8-bit signless integer values for shared memory |
Results: ¶
Result | Description |
---|---|
mem_desc | MemDesc describing the data in SLM |
xegpu.create_nd_tdesc
(xegpu::CreateNdDescOp) ¶
Create nd-tensor descriptor operation
Syntax:
operation ::= `xegpu.create_nd_tdesc` $source ``
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
(`,` `shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
`,` `strides``:` custom<DynamicIndexList>($strides, $const_strides))?
attr-dict `:` type($source) `->` qualified(type($TensorDesc))
The “create_nd_tdesc” operation creates a TensorDescType which represents a sub-view of a 1D/2D memory region inside the one or two innermost dimensions of the source. (It can be extended to support n-D memory regions if needed in the future.) Elements in the subview are contiguous in each dimension. It encodes the following important information for supporting Intel hardware features:
Arguments:
`source`: an object representing (the starting address/pointer of) a memory region. It can be either a memref object or simply a pointer represented by the uint64_t type. For dynamic memrefs or pointers, the shape and layout information of the memory region must be passed explicitly via the `shape` and `strides` parameters.
`offsets`: index values representing the offsets from the “source” in each dimension at which the subview of the target memory will be created. They are encoded via “offsets” and “const_offsets”, so that they can accept various forms, such as operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
`shape`: the shape information of the memory region pointed to by the “source”. It is typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. But if “source” is simply a pointer represented as uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the shape information has to be passed explicitly via the “shape” and “const_shape” arguments.
`strides`: the strides of the memory region pointed to by the “source”. Similar to shape, it is typically encoded via the MemRefType of the source. But if “source” is simply a pointer represented as uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the strides information has to be passed explicitly via the “strides” and “const_strides” arguments.
Results:
`res`: nd tensor descriptor
Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
%0 = memref.alloc() : memref<1024x1024xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0]: memref<1024x1024xf32> -> TensorDesc<8x16xf32>
Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
%0 = memref.alloc(%h, %w) : memref<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], shape: [%h, %w], strides: [%w, %c1] : memref<?x?xf32> -> TensorDesc<8x16xf32>
Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
%0 = ... : ui64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], shape: [%h, %w], strides: [%w, %c1] : ui64 -> TensorDesc<8x16xf32>
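How offsets and strides select a contiguous 2D subview can be sketched in Python. This is a simplified model, not the op's implementation: `nd_subview_indices` is a hypothetical helper that returns linear element indices relative to the start of the source, and it ignores bounds handling.

```python
from typing import List, Tuple


def nd_subview_indices(
    offsets: Tuple[int, int], strides: Tuple[int, int], tile_shape: Tuple[int, int]
) -> List[List[int]]:
    """Linear element index of every element of a 2D tile.

    Element (r, c) of the tile sits at
    (offset_row + r) * stride_row + (offset_col + c) * stride_col.
    """
    (off_r, off_c), (str_r, str_c), (rows, cols) = offsets, strides, tile_shape
    return [
        [(off_r + r) * str_r + (off_c + c) * str_c for c in range(cols)]
        for r in range(rows)
    ]


if __name__ == "__main__":
    # An 8x16 tile at offset [0, 0] of a row-major 1024x1024 buffer,
    # i.e. strides [1024, 1], as in Example 1.
    tile = nd_subview_indices((0, 0), (1024, 1), (8, 16))
    print(tile[0][:4])  # first row starts at element 0
    print(tile[1][:4])  # second row starts one row stride later
```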
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
const_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
const_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
source | non-0-ranked.memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | variadic of index |
shape | variadic of index |
strides | variadic of index |
Results: ¶
Result | Description |
---|---|
TensorDesc | TensorDesc describing regions of interested data. |
xegpu.create_tdesc
(xegpu::CreateDescOp) ¶
Create scattered tensor descriptors (TensorDesc).
Syntax:
operation ::= `xegpu.create_tdesc` $source `,` $offsets attr-dict `:` type($source) `,` type($offsets) `->` qualified(type($TensorDesc))
“create_tdesc” is similar to “create_nd_tdesc” in that it creates a Tensor Descriptor (TensorDescType) for a memory region. While “create_nd_tdesc” creates contiguous subviews, “create_tdesc” creates non-contiguous (scattered) subviews, allowing each work-item in a subgroup to specify its own offset. It accepts the following parameters:
Arguments:
`source`: a 1D memref or pointer (i64, i32, ui64, ui32) representing the flattened memory object.
`offsets`: a vector containing the offsets of each access point. Its size is fixed to the hardware-supported subgroup size, e.g., 16 on PVC, implying that each element in the vector corresponds to a work-item (SIMT lane) in the subgroup.
Results:
`res`: scattered tensor descriptor
The first dimension of the result TensorDesc corresponds to work-items, so it should match the dimension of offsets. It may also have a second dimension corresponding to the chunk_size if the chunk size is larger than 1.
Example 1: It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %a, %0: memref<1024xf32>, vector<4xindex> -> TensorDesc<4xf32>
Example 2: It assumes the subgroup size is 4 and each work-item accesses 8 elements, for a total of 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
-> TensorDesc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
Example 3: It is similar to Example 2, but there are some overlaps among work-items. It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
-> TensorDesc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
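The element ranges quoted in the three examples follow directly from the offsets and the chunk size. The Python sketch below reproduces them; `scattered_indices` is a hypothetical helper used only to illustrate the addressing.

```python
from typing import List, Sequence


def scattered_indices(offsets: Sequence[int], chunk_size: int = 1) -> List[List[int]]:
    """Element indices touched by each work-item.

    Row i lists the `chunk_size` consecutive elements that start at offsets[i].
    """
    return [[off + i for i in range(chunk_size)] for off in offsets]


if __name__ == "__main__":
    # Example 1: one element per work-item.
    print(scattered_indices([0, 16, 32, 64]))
    # Example 2: 8 consecutive elements per work-item, 32 elements in total.
    print(scattered_indices([0, 16, 32, 64], chunk_size=8))
    # Example 3: neighbouring work-items overlap in 4 elements.
    print(scattered_indices([0, 4, 8, 12], chunk_size=8))
```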
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface
Effects: MemoryEffects::Effect{}
Operands: ¶
Operand | Description |
---|---|
source | 1D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | fixed-length vector of index values |
Results: ¶
Result | Description |
---|---|
TensorDesc | TensorDesc describing regions of interested data. |
xegpu.dpas
(xegpu::DpasOp) ¶
It performs mma computation
Syntax:
operation ::= `xegpu.dpas` $lhs `,` $rhs (`,` $acc^)? attr-dict `:` type($lhs)`,` type($rhs) (`,` type($acc)^)? `->` type($result)
DPAS performs matrix multiplication on matrix A of size `mxk` and matrix B of size `kxn`, and accumulates on matrix C of size `mxn` into a result matrix of the same size, where `m=8`, `n=16` and `k=8 * 32/bit_width_of_elem_type`. So for the fp16 data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS also requires A and B to be loaded with the required data layout. Specifically, the VNNI layout is required for the B operand. It is achieved by adding the `packed` attribute to the `load_nd` operator. Due to the VNNI transformation, B operands can be represented as a 3D vector, with the last dimension representing the VNNI factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>` can be represented as `B: vector<8x16x2xf16>`.
In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result, which are represented as 1D vectors. Please refer to [OpenCL Intel extensions](https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html) for more details about the fragment distribution.
Note: on PVC, the hardware can perform a load with VNNI transformation when the data element type is 16-bit or lower precision, taking 2 or 4 elements from the first dimension and inserting them into the newly added innermost dimension.
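The shape rules and the VNNI repacking of B can be sketched in plain Python. This is an illustrative model under assumptions: `dpas_shapes` and `vnni_pack` are hypothetical helpers, matrices are nested lists, and the packing simply groups consecutive rows of B into a new innermost dimension as the note describes.

```python
from typing import Dict, List, Tuple


def dpas_shapes(elem_bits: int) -> Dict[str, Tuple[int, int]]:
    """Matrix shapes for a given A/B element bit width (m=8, n=16)."""
    m, n = 8, 16
    k = 8 * 32 // elem_bits
    return {"A": (m, k), "B": (k, n), "C": (m, n)}


def vnni_pack(b: List[List[int]], elem_bits: int) -> List[List[List[int]]]:
    """Repack a k x n matrix into (k / factor) x n x factor.

    `factor` consecutive rows of B are gathered into the new innermost
    dimension, where factor = 32 / elem_bits.
    """
    factor = 32 // elem_bits
    k, n = len(b), len(b[0])
    if k % factor != 0:
        raise ValueError("k must be a multiple of the VNNI factor")
    return [
        [[b[kk * factor + f][col] for f in range(factor)] for col in range(n)]
        for kk in range(k // factor)
    ]


if __name__ == "__main__":
    print(dpas_shapes(16))  # 16-bit elements give A: 8x16, B: 16x16, C: 8x16
    b = [[r * 16 + c for c in range(16)] for r in range(16)]
    packed = vnni_pack(b, 16)
    print(len(packed), len(packed[0]), len(packed[0][0]))  # 16x16 becomes 8x16x2
```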
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands: ¶
Operand | Description |
---|---|
lhs | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3 |
rhs | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3 |
acc | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2 |
Results: ¶
Result | Description |
---|---|
result | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2 |
xegpu.fence
(xegpu::FenceOp) ¶
It synchronizes memory accesses.
Syntax:
operation ::= `xegpu.fence` `memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` `` $fence_scope attr-dict
It synchronizes memory accesses between a write and the following read or write.
1. `memory_kind` describes the memory kind. “global” means the global memory; “slm” means the shared local memory.
2. `fence_scope` describes the scope of the fence. “workgroup” means that the scope is within each workgroup. “gpu” means that the scope is across workgroups within the GPU.
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
memory_kind | ::mlir::xegpu::MemorySpaceAttr | Describe the location of data described by a `TensorDesc`: Global device memory (`Global`) or Shared local memory (`SLM`). |
fence_scope | ::mlir::xegpu::FenceScopeAttr | Describes the scope of fence. "workgroup" means that the scope is within each work group. "gpu" means the scope is across work groups within the gpu. |
xegpu.init_nbarrier
(xegpu::InitNbarrierOp) ¶
It assigns a named barrier to the current thread.
Syntax:
operation ::= `xegpu.init_nbarrier` $nbarrier_id `,` $participant_thread_num attr-dict `:`
type($nbarrier_id) `,` type($participant_thread_num) `->` qualified(type($result))
InitNbarrierOp assigns the named barrier with the specified barrier ID (0~31) to the current thread. Multiple threads may bind to the same named barrier, and `participant_thread_num` specifies the total number of threads associated with the nbarrier. It returns an object of NbarrierType representing the barrier.
Operands: ¶
Operand | Description |
---|---|
nbarrier_id | 8-bit signless integer |
participant_thread_num | 8-bit signless integer |
Results: ¶
Result | Description |
---|---|
result | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.load
(xegpu::LoadGatherOp) ¶
Load a set of scattered data points from memory.
Syntax:
operation ::= `xegpu.load` $source
(`[` $offsets^ `]`)? `,`
$mask prop-dict
attr-dict `:` type(operands) `->` type($value)
This operation (aka. load) loads data for each work-item. The output describes the data being loaded at the subgroup level, so its size is consistent with the number of work-items in a subgroup. When the chunk size is larger than 1, the output vector is a 2D vector, with dim-0 corresponding to work-items and dim-1 corresponding to the chunk size loaded by each work-item. The mask operand masks out memory accesses, so it is safe to pass out-of-boundary addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
In SIMT mode, the result is a 1D vector that represents the data to be loaded by each work-item. If its size is not 1, the size should be equal to the chunk size.
Arguments:
`source`: represents the memory region to be loaded from, which can be either a tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32). In the case of a tensor_desc, offsets come from the producer create_tdesc op. tensor_desc cannot be used in SIMT mode.
`offsets`: represents offsets from source. Required if `source` is not a TensorDescType. offsets is a vector of `index` type, and the vector length is either the subgroup size or 1 in SIMT mode. A scalar offset is also valid for SIMT mode.
`mask`: a vector of `i1` type, which is used to mask out the memory access. Its size is equal to the subgroup size, or 1 in SIMT mode. A scalar mask is also valid for SIMT mode.
`chunk_size`: (optional) represents the number of contiguous elements to load per work-item.
`l1_hint`, `l2_hint`, `l3_hint`: optional cache hints for each level of cache.
Results:
`res`: represents loaded data
Example 1:
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
vector<16xi1> -> vector<16xf32>
Example 2:
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
Example 3: A variant that accepts a memref as the base pointer and an offset instead of a scattered TensorDesc. It combines “create scattered TensorDesc” and “load with scattered TensorDesc”. The source operand could also be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc for the restrictions on the memref.
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16]: vector<16xi1>
%val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
Example 4 (SIMT mode): SIMT mode only accepts the offsets variant. chunk_size can be inferred from result type. In this example, chunk_size is 8.
%3 = xegpu.load %1[%2], %0 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}>
: memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
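The interaction of offsets, mask and chunk size can be modelled in Python as below. It is a sketch, not the op's definition: `gather_load` is a hypothetical helper, and returning `None` for a masked-out lane is an assumption made here, since the text above does not say what a disabled lane holds.

```python
from typing import List, Optional, Sequence


def gather_load(
    memory: Sequence[float],
    offsets: Sequence[int],
    mask: Sequence[bool],
    chunk_size: int = 1,
) -> List[Optional[List[float]]]:
    """Load `chunk_size` consecutive elements per enabled work-item.

    Masked-out lanes never touch memory, so their offsets may point outside
    the buffer. Their result is None here (assumption: value unspecified).
    """
    if len(offsets) != len(mask):
        raise ValueError("offsets and mask must have one entry per work-item")
    result: List[Optional[List[float]]] = []
    for off, enabled in zip(offsets, mask):
        if enabled:
            result.append([memory[off + i] for i in range(chunk_size)])
        else:
            result.append(None)  # lane disabled: no memory access is made
    return result


if __name__ == "__main__":
    memory = [float(i) for i in range(100)]
    # The third offset is out of bounds but masked, so the load is still safe.
    print(gather_load(memory, [0, 16, 5000], [True, True, False], chunk_size=2))
```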
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
chunk_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands: ¶
Operand | Description |
---|---|
source | TensorDesc describing regions of interested data. or 1D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | fixed-length vector of index values or index |
mask | fixed-length vector of 1-bit signless integer values or 1-bit signless integer |
Results: ¶
Result | Description |
---|---|
value | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
xegpu.load_matrix
(xegpu::LoadMatrixOp) ¶
Syntax:
operation ::= `xegpu.load_matrix` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
This operation loads a 2D block of data from shared local memory (SLM) as specified by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
Arguments:
`mem_desc`: the memory descriptor identifying the SLM region.
`offsets`: the coordinates within the matrix to read from.
`layout`: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr.
Results:
`res`: the matrix elements loaded from SLM.
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
layout | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance |
Operands: ¶
Operand | Description |
---|---|
mem_desc | MemDesc describing the data in SLM |
offsets | variadic of index |
Results: ¶
Result | Description |
---|---|
res | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
xegpu.load_nd
(xegpu::LoadNdOp) ¶
Loads an n-D block from memory (represented by TensorDesc) to registers (represented by vector)
Syntax:
operation ::= `xegpu.load_nd` $TensorDesc ``
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
LoadNdOp essentially mimics the hardware block read instruction to read a block of data from memory to registers. It takes a set of optional cache hints for each level of cache: L1, L2 and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute will be masked. VNNI transformation is a hardware feature of Intel GPUs, which is used to do data packing during the load for the B operand of a matrix operation if the bit width of the data type is less than 32 bits, e.g., fp16. Transpose is another Intel hardware feature, which transposes the data while loading it if the data type is fp32 or fp64. This implies that VNNI and transpose cannot exist at the same time. It is only available for 1D or 2D blocked tensor_desc.
In SIMT mode, result vector represents the data to be loaded by each work-item.
Example 1:
xegpu.load_nd %1 {transpose = [1, 0],
l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<streaming>}
: !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
Example 2 (SIMT mode):
xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
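The effect of the transpose attribute in Example 1 (an 8x16 block read that yields a 16x8 vector) can be modelled in Python. This is a sketch only: the function below is a hypothetical stand-in named after the op, memory is a nested list, and cache hints and VNNI packing are left out.

```python
from typing import List, Sequence, Tuple


def load_nd(
    memory: Sequence[Sequence[float]],
    offsets: Tuple[int, int],
    tile_shape: Tuple[int, int],
    transpose: bool = False,
) -> List[List[float]]:
    """Read a rows x cols block starting at `offsets`, optionally transposed."""
    (r0, c0), (rows, cols) = offsets, tile_shape
    tile = [[memory[r0 + r][c0 + c] for c in range(cols)] for r in range(rows)]
    if transpose:
        # Swap the two dimensions: element (r, c) moves to (c, r).
        tile = [list(column) for column in zip(*tile)]
    return tile


if __name__ == "__main__":
    memory = [[float(r * 100 + c) for c in range(32)] for r in range(32)]
    block = load_nd(memory, (0, 0), (8, 16), transpose=True)
    print(len(block), len(block[0]))  # an 8x16 block is returned as 16x8
```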
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
packed | ::mlir::UnitAttr | unit attribute |
transpose | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interested data. |
offsets | variadic of index |
Results: ¶
Result | Description |
---|---|
value | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
xegpu.mem_desc_subview
(xegpu::MemDescSubviewOp) ¶
Syntax:
operation ::= `xegpu.mem_desc_subview` $src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))
Creates a subview of a memory descriptor. The resulting memory descriptor can have a lower rank than the source; in this case, the result dimensions correspond to the higher-order dimensions of the source memory descriptor.
Arguments:
`src`: a memory descriptor.
`offsets`: the coordinates within the matrix the subview will be created from.
Results:
`res`: a memory descriptor with smaller size.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
src | MemDesc describing the data in SLM |
offsets | variadic of index |
Results: ¶
Result | Description |
---|---|
res | MemDesc describing the data in SLM |
xegpu.nbarrier_arrive
(xegpu::NbarrierArriveOp) ¶
It signals the arrival at the named barrier.
Syntax:
operation ::= `xegpu.nbarrier_arrive` $nbarrier attr-dict `:` qualified(type($nbarrier))
NbarrierArriveOp signals the hardware (or other threads) that the current thread has produced its data for the consumer threads. Once the hardware has been signalled by participant_thread_num threads for the named barrier, it notifies the threads waiting on that barrier to continue their work.
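The arrive/wait protocol described above can be pictured with a simple counter. The following is only a host-side Python sketch: the class `NamedBarrierModel` and its methods are hypothetical names chosen for illustration, not part of XeGPU, and the model ignores barrier reuse and hardware timing.

```python
class NamedBarrierModel:
    """Toy model of one named barrier (illustrative, not the hardware protocol)."""

    def __init__(self, participant_thread_num: int):
        self.participant_thread_num = participant_thread_num
        self.arrived = 0

    def arrive(self) -> None:
        # Conceptually corresponds to one thread issuing nbarrier_arrive.
        self.arrived += 1

    def is_complete(self) -> bool:
        # Waiting threads may continue once every participant has arrived.
        return self.arrived >= self.participant_thread_num


barrier = NamedBarrierModel(participant_thread_num=4)
for _ in range(3):
    barrier.arrive()
pending = barrier.is_complete()   # only three of four participants arrived
barrier.arrive()
released = barrier.is_complete()  # all four participants arrived
```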
Operands: ¶
Operand | Description |
---|---|
nbarrier | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.nbarrier_wait
(xegpu::NbarrierWaitOp) ¶
It waits for a named barrier.
Syntax:
operation ::= `xegpu.nbarrier_wait` $nbarrier attr-dict `:` qualified(type($nbarrier))
NbarrierWaitOp signals the hardware which named barrier the current thread is waiting for, such that it can get notified when the named barrier is completed.
Operands: ¶
Operand | Description |
---|---|
nbarrier | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.prefetch
(xegpu::PrefetchOp) ¶
Prefetches a set of scattered data points to cache
Syntax:
operation ::= `xegpu.prefetch` $source
(`[` $offsets^ `]`)?
prop-dict
attr-dict `:` type(operands)
It issues instructions to prefetch a set of scattered data points from memory to each level of the cache based on their cache policy. As compared to prefetch_nd, which works on non-scattered TensorDesc, it works on scattered TensorDesc instead.
Arguments:
- source: represents the memory region to be loaded from, which can be either a tensor_desc, a 1D memref, or a pointer (ui64, ui32, i64 or i32). In case of tensor_desc, offsets come from the producer create_tdesc op. tensor_desc cannot be used in SIMT mode.
- offsets: represents offsets from source. Required if source is not a TensorDescType. offsets is a vector of index type, and the vector length is either the subgroup size or 1 in SIMT mode. A scalar offset is also valid in SIMT mode.
- l1_hint, l2_hint, l3_hint: optional cache hints for each level of cache.
- offset_align_byte: required if source is a pointer; not allowed otherwise. Represents the alignment in bytes of each offset in offsets.
Example 1:
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
Example 2: A variant accepts a memref as base pointer and an offset instead of a scattered TensorDesc. It combines “create scattered TensorDesc” and “prefetch with scattered TensorDesc”. The source operand could be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc for the restrictions on memref.
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<4xindex>
Example 3 (SIMT mode): SIMT mode only accepts the offsets variant.
xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<256xf32>, vector<1xindex>
Example 4 (SIMT mode): SIMT mode only accepts the offsets variant.
xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>,
offset_align_byte = 2}
: i64, vector<1xindex>
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
offset_align_byte | ::mlir::IntegerAttr | 64-bit signless integer attribute |
Operands: ¶
Operand | Description |
---|---|
source | TensorDesc describing regions of interested data. or 1D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | fixed-length vector of index values or index |
xegpu.prefetch_nd
(xegpu::PrefetchNdOp) ¶
Prefetches an n-D block to cache
Syntax:
operation ::= `xegpu.prefetch_nd` $TensorDesc ``
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc))
It issues an instruction to prefetch a block of data from contiguous memory regions to each level of the cache based on their cache policy.
Example:
xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<8x16xf16>
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interested data. |
offsets | variadic of index |
xegpu.store
(xegpu::StoreScatterOp) ¶
Store data to scattered memory locations.
Syntax:
operation ::= `xegpu.store` $value `,`
$dest
(`[` $offsets^ `]`)? `,`
$mask
prop-dict
attr-dict `:` type(operands)
This operation stores data to scattered memory locations. The value is typically a 1D vector. When the chunk size of the TensorDesc is larger than 1, it is a 2D vector instead. In the latter case, dim-1 of the value corresponds to the SIMD lanes and dim-0 of the value corresponds to the chunk size stored per lane. So store_scatter has a transpose effect, similar to load_gather. Therefore, a transpose attribute is introduced on purpose, making sure users are aware of this implicit transformation.
In SIMT mode, the value is a 1D vector that represents the data to be stored by each work-item. If its size is not 1, the size should be equal to the chunk size.
Arguments:
- value: represents the data to be stored.
- dest: represents the memory region to be stored to, which can be either a tensor_desc, a 1D memref, or a pointer (ui64, ui32, i64 or i32). In case of tensor_desc, offsets come from the producer create_tdesc op. tensor_desc cannot be used in SIMT mode.
- offsets: represents offsets from dest. Required if dest is not a TensorDescType. offsets is a vector of index type, and the vector length is either the subgroup size or 1 in SIMT mode. A scalar offset is also valid in SIMT mode.
- mask: a vector of i1 type, which is used to mask out the memory access. mask is a vector of size equal to the subgroup size, or 1 in SIMT mode. A scalar mask is also valid in SIMT mode.
- chunk_size: (optional) represents the number of contiguous elements to store per work-item.
- l1_hint, l2_hint, l3_hint: optional cache hints for each level of cache.
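As a rough mental model of how value, offsets, mask, and chunk_size interact, the scattered store can be sketched in plain Python. This is an illustration under stated assumptions, not the lowering: offsets are taken to count elements of dest, `value` is indexed as `value[lane][element]` (matching the `vector<16x8xf32>` shape used in the examples for this op), and the function name `scatter_store` is hypothetical.

```python
def scatter_store(dest, value, offsets, mask, chunk_size=1):
    """Write each enabled lane's data to dest, starting at that lane's offset.

    Assumption: offsets count elements of dest, and each lane writes
    chunk_size contiguous elements.
    """
    for lane, (offset, enabled) in enumerate(zip(offsets, mask)):
        if not enabled:
            continue  # masked-out lanes do not touch memory
        chunk = value[lane] if chunk_size > 1 else [value[lane]]
        for j in range(chunk_size):
            dest[offset + j] = chunk[j]
    return dest


memory = [0] * 12
scatter_store(
    memory,
    value=[[1, 2], [3, 4], [5, 6]],
    offsets=[0, 4, 8],
    mask=[True, False, True],
    chunk_size=2,
)
```

In this run the middle lane is masked out, so only the chunks at offsets 0 and 8 are written.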
Example 1:
xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered_tdesc_attr<>>, vector<16xi1>
Example 2:
xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
Example 3: A variant accepts a memref as base pointer and an offset instead of a scattered TensorDesc. It combines “create scattered TensorDesc” and “store with scattered TensorDesc”. The dest operand could be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on memref.
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16]: vector<16xi1>
xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
Example 4 (SIMT mode): SIMT mode only accepts the offsets variant. chunk_size can be inferred from value type. In this example, chunk_size is 8.
xegpu.store %0, %1[%2], %3 <{l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<8xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
chunk_size | ::mlir::IntegerAttr | 64-bit signless integer attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands: ¶
Operand | Description |
---|---|
value | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
dest | TensorDesc describing regions of interested data. or 1D memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | fixed-length vector of index values or index |
mask | fixed-length vector of 1-bit signless integer values or 1-bit signless integer |
xegpu.store_matrix
(xegpu::StoreMatrixOp) ¶
Syntax:
operation ::= `xegpu.store_matrix` $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands)
This operation stores a 2D data fragment into the shared local memory region specified by a 2D mem_desc. Only 2D memory descriptors are supported; use the subview operation to obtain a 2D mem_desc from a higher-rank descriptor if needed.
Arguments:
- mem_desc: the memory descriptor specifying the SLM region.
- offsets: the coordinates within the matrix where the data will be written.
- data: the values to be stored in the matrix.
- layout: [optional] An attribute for guiding distributions among subgroups and/or work-items. It currently can accept either LayoutAttr or SliceAttr.
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
layout | ::mlir::xegpu::DistributeLayoutAttr | DistributeLayoutAttr instance |
Operands: ¶
Operand | Description |
---|---|
data | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
mem_desc | MemDesc describing the data in SLM |
offsets | variadic of index |
xegpu.store_nd
(xegpu::StoreNdOp) ¶
Stores an n-D block register region back to memory, currently only supports 2D
Syntax:
operation ::= `xegpu.store_nd` $value `,`
$TensorDesc ``
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
StoreNdOp essentially mimics the hardware block write instruction to write a block of data from registers into the memory region described by the TensorDesc. It takes a set of optional cache hints for each level of cache: L1, L2 and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute will be masked. It is only available for 1D or 2D blocked tensor_desc.
In SIMT mode, the input vector represents the data to be stored by each work-item.
Example 1:
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
l3_hint = #xegpu.cache_hint<write_through>}
: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
Example 2 (SIMT mode):
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
l2_hint = #xegpu.cache_hint<write_back>,
l3_hint = #xegpu.cache_hint<write_through>}
: vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands: ¶
Operand | Description |
---|---|
value | fixed-length vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values |
TensorDesc | TensorDesc describing regions of interested data. |
offsets | variadic of index |
xegpu.update_nd_offset
(xegpu::UpdateNdOffsetOp) ¶
It updates the offsets for the TensorDesc.
Syntax:
operation ::= `xegpu.update_nd_offset` $TensorDesc `,`
custom<DynamicIndexList>($offsets, $const_offsets)
attr-dict `:` qualified(type($result))
The op updates the offsets of the given TensorDesc. The offsets are relative to the current position and are expressed in number of elements. The result is a TensorDesc of the same type as the input.
Example:
%2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
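Because the offsets are relative, repeated updates accumulate. A minimal Python sketch of that bookkeeping follows; the helper name `update_nd_offset` mirrors the op only for readability, and the plain list standing in for the descriptor's start position is an assumption of this sketch.

```python
def update_nd_offset(current, delta):
    """Return the new block start after applying one relative offset per dimension."""
    if len(current) != len(delta):
        raise ValueError("offset rank must match the descriptor rank")
    return [c + d for c, d in zip(current, delta)]


start = [0, 0]
for _ in range(3):  # e.g. three steps of 16 elements along the second dimension
    start = update_nd_offset(start, [0, 16])
```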
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interested data. |
offsets | variadic of index |
Results: ¶
Result | Description |
---|---|
result | TensorDesc describing regions of interested data. |
xegpu.update_offset
(xegpu::UpdateOffsetOp) ¶
It updates the offsets for the given tensor descriptor
Syntax:
operation ::= `xegpu.update_offset` $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
It behaves similarly to update_nd_offset in that it updates the offsets of a TensorDesc, and the offsets are relative to the current position, expressed in number of elements. However, update_nd_offset updates the start point of a 2D block, so its offsets contain two elements representing the shift in each dimension. update_offset updates the offset per work-item, so its offsets contain values representing the shift for each work-item.
Example:
```mlir
%off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
%2 = xegpu.update_offset %1, %off :
!xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<chunk_size=2>>, vector<4xindex>
```
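The per-work-item update can be sketched the same way in Python. The starting offsets below are made-up values for illustration, and the helper name `update_offset` is only borrowed from the op for readability.

```python
def update_offset(lane_offsets, delta):
    """Shift each work-item's offset by its own relative amount (in elements)."""
    if len(lane_offsets) != len(delta):
        raise ValueError("one delta is expected per work-item offset")
    return [o + d for o, d in zip(lane_offsets, delta)]


# Hypothetical starting offsets for four work-items, each shifted by 32 elements.
offsets = update_offset([0, 8, 16, 24], [32, 32, 32, 32])
```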
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interested data. |
offsets | fixed-length vector of index values |
Results: ¶
Result | Description |
---|---|
result | TensorDesc describing regions of interested data. |
Attributes ¶
BlockTensorDescAttr ¶
A composite attribute for TensorDescType
Syntax:
#xegpu.block_tdesc_attr<
MemorySpaceAttr, # memory_space
IntegerAttr, # array_length
BoolAttr # boundary_check
>
BlockTensorDesc (or block_tdesc_attr) is a composite attribute defined for TensorDescType to describe the following properties of a TensorDesc:
1. memory_space: describes where the data block described by the TensorDesc is located, Global device memory or Shared local memory. It defaults to Global.
2. array_length: describes how many horizontally consecutive blocks will be loaded by a hardware load instruction. If the TensorDesc shape is 8x16 and array_length = 2, the loaded block shape is actually 8x32. Its default value is 1.
3. boundary_check: indicates to the hardware whether to perform out-of-boundary checks. The default value is true.
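The effect of array_length on the loaded shape is simple arithmetic. A small Python sketch, with `loaded_block_shape` as a hypothetical helper name:

```python
def loaded_block_shape(tdesc_shape, array_length=1):
    """Shape actually loaded when array_length blocks sit side by side horizontally."""
    rows, cols = tdesc_shape
    return (rows, cols * array_length)


shape = loaded_block_shape((8, 16), array_length=2)
```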
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
memory_space | MemorySpaceAttr | Data memory location |
array_length | IntegerAttr | Number of continuous blocks to load |
boundary_check | BoolAttr | Checking the out of boundary access |
CachePolicyAttr ¶
Describe the cache settings for prefetch/load/store operators
Syntax:
#xegpu.cache_hint<
::mlir::xegpu::CachePolicy # value
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::CachePolicy | an enum of type CachePolicy |
FenceScopeAttr ¶
Describes the scope of fence. “workgroup” means that the scope is within each work group. “gpu” means the scope is across work groups within the gpu.
Syntax:
#xegpu.fence_scope<
::mlir::xegpu::FenceScope # value
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::FenceScope | an enum of type FenceScope |
LayoutAttr ¶
Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor.
Syntax:
#xegpu.layout<
DenseI32ArrayAttr, # sg_layout
DenseI32ArrayAttr, # sg_data
DenseI32ArrayAttr, # inst_data
DenseI32ArrayAttr, # lane_layout
DenseI32ArrayAttr, # lane_data
DenseI32ArrayAttr # order
>
XeGPU operations use LayoutAttr to define how data is distributed across subgroups and work-items. This attribute is specified in tensor descriptors during tensor description creation. LayoutAttr includes the following parameters:
- sg_layout: Specifies the total number of subgroups and their layout within a workgroup. It is mandatory for workgroup-level programming. Its presence implies workgroup-level code.
- sg_data: Defines the data size accessed per subgroup. It is optionally used with sg_layout for workgroup-level programming. When it is left empty, the size accessed per subgroup can be derived from the tensor shape and sg_layout using the formula: sg_data[i] = tensor_shape[i] / sg_layout[i].
- inst_data: Specifies the data size that is processed by an instruction. It is optionally used with lane_layout. When it is left empty, the data size per instruction is equivalent to the sg_data for workgroup-level programming or equivalent to tensor shape for subgroup-level programming.
- lane_layout: Specifies the total number of work-items and their arrangement within a subgroup. It is mandatory for subgroup-level programming and optional for workgroup-level programming.
- lane_data: Specifies the shape of the tensor fragment that each lane accesses. It defines a single, minimal distribution unit. Processing the entire tensor may require one or more distribution units per hardware instruction.
- order: Specifies the dimension order used to linearize n-dimensional sg_layout and lane_layout to 1-dimensional layout. The first dimension in the order list is the fastest-changing dimension. If it is not present, the default value is [1, 0].
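The sg_data derivation formula can be written out directly. The Python sketch below uses a hypothetical helper `derive_sg_data`; the divisibility check is an assumption of this sketch and is not stated by the formula itself.

```python
def derive_sg_data(tensor_shape, sg_layout):
    """Apply sg_data[i] = tensor_shape[i] / sg_layout[i] for every dimension."""
    if len(tensor_shape) != len(sg_layout):
        raise ValueError("tensor_shape and sg_layout must have the same rank")
    sg_data = []
    for size, count in zip(tensor_shape, sg_layout):
        if size % count != 0:
            # Assumption of this sketch: only even splits are modeled.
            raise ValueError("tensor dimension is not divisible by sg_layout")
        sg_data.append(size // count)
    return sg_data


sg_data = derive_sg_data([32, 64], [2, 4])
```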
Examples:
- Subgroup level layout:
#xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>
In this example, there are 16 work-items per subgroup, organized as [[0, 1, 2, .., 7],[8, 9, .., 15]]. The distribution unit is 1x1.
- Subgroup level layout with order:
#xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1], order = [0, 1]>
In this example, there are 16 work-items per subgroup, organized as [[0, 2, 4, …, 14], [1, 3, 5, …, 15]]. The distribution unit is 1x1.
- Subgroup level layout with inst_data:
#xegpu.layout<inst_data = [8, 16], lane_layout = [2, 8], lane_data = [2, 2]>
In this example, the original problem size is partitioned into smaller subproblems of dimensions [8, 16], which are then distributed among 16 work-items arranged as [[0, 1, 2, …, 7], [8, 9, …, 15]]. Each work-item is assigned four 2x2 blocks in a round-robin manner.
- Workgroup level layout:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [2, 8], lane_data = [1, 1]>
In this example, the layout represents a workgroup distribution. A workgroup consists of 8 subgroups arranged as [[0, 1, 2, 3], [4, 5, 6, 7]]. Each subgroup accesses a 16x16 block per instruction, which is further distributed to 16 work-items organized as [[0, 1, 2, .., 7],[8, 9, .., 15]].
- Workgroup level layout with order:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [2, 8], lane_data = [1, 1], order = [0, 1]>
In this example, the layout represents a workgroup distribution. A workgroup consists of 8 subgroups arranged as [[0, 2, 4, 6], [1, 3, 5, 7]]. Each subgroup accesses a 16x16 block per instruction, which is further distributed to 16 work-items organized as [[0, 2, 4, …, 14], [1, 3, 5, …, 15]].
- Workgroup level layout with inst_data:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], inst_data = [8, 16], lane_layout = [2, 8], lane_data = [1, 1]>
This example is similar to the previous ones, but the inst_data parameter divides sg_data into two instructions, each processing an 8x16 block. These blocks are further distributed across 16 work-items with a distribution unit of 1x1. Unlike the 2x2 distribution unit in example 3, which results in accessing contiguous 2x2 blocks, the 1x1 distribution unit may result in non-contiguous access.
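The id arrangements quoted in the examples follow from how order linearizes a layout. The Python sketch below reproduces them for the 2-D case; `linear_ids` and `as_grid` are hypothetical helper names, and the code is an illustration of the linearization rule rather than the dialect's implementation.

```python
from itertools import product


def linear_ids(layout, order=None):
    """Map each coordinate in `layout` to a linear id.

    order[0] is the fastest-changing dimension. The default [1, 0] used
    here only covers 2-D layouts.
    """
    if order is None:
        order = [1, 0]
    strides = {}
    stride = 1
    for dim in order:
        strides[dim] = stride
        stride *= layout[dim]
    return {
        coord: sum(c * strides[d] for d, c in enumerate(coord))
        for coord in product(*(range(n) for n in layout))
    }


def as_grid(layout, order=None):
    """Arrange the linear ids of a 2-D layout as rows and columns."""
    ids = linear_ids(layout, order)
    rows, cols = layout
    return [[ids[(r, c)] for c in range(cols)] for r in range(rows)]


default_grid = as_grid([2, 8])                # lane_layout = [2, 8], default order
swapped_grid = as_grid([2, 8], order=[0, 1])  # lane_layout = [2, 8], order = [0, 1]
sg_grid = as_grid([2, 4], order=[0, 1])       # sg_layout = [2, 4], order = [0, 1]
```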
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
sg_layout | DenseI32ArrayAttr | |
sg_data | DenseI32ArrayAttr | |
inst_data | DenseI32ArrayAttr | |
lane_layout | DenseI32ArrayAttr | |
lane_data | DenseI32ArrayAttr | |
order | DenseI32ArrayAttr | |
MemLayoutAttr ¶
Specifies memory layouts with named attributes.
This attribute stores a collection of named attributes that describe memory layout properties such as stride, block, etc.
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
attrs | DictionaryAttr | |
MemorySpaceAttr ¶
Describe the location of data described by a TensorDesc: Global device memory (Global) or Shared local memory (SLM).
Syntax:
#xegpu.memory_space<
::mlir::xegpu::MemorySpace # value
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::MemorySpace | an enum of type MemorySpace |
RangeAttr ¶
Specifies a half-open range
Syntax:
#xegpu.range<
IntegerAttr, # start
IntegerAttr # end
>
RangeAttr is an attribute that defines a half-open range [start, end). The range is inclusive of the start value and exclusive of the end value. One use of this attribute is to specify a subgroup id range, which can be attached to an scf.if op as follows:
scf.if %cond {
// some operations
} {sg_id_range = #xegpu.range<[2, 4]>}
In this case, the scf.if op will only be executed for subgroup IDs 2 and 3.
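The half-open semantics can be checked with a one-line predicate. In this Python sketch the helper name `in_sg_id_range` and the choice of 8 subgroups are assumptions made for illustration.

```python
def in_sg_id_range(sg_id, start, end):
    """True when sg_id lies in the half-open range [start, end)."""
    return start <= sg_id < end


# With 8 subgroups (an arbitrary count for this sketch) and the range [2, 4):
active = [sg for sg in range(8) if in_sg_id_range(sg, 2, 4)]
```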
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
start | IntegerAttr | |
end | IntegerAttr | |
ScatterTensorDescAttr ¶
A composite attribute for TensorDescType
Syntax:
#xegpu.scatter_tdesc_attr<
MemorySpaceAttr, # memory_space
IntegerAttr # chunk_size
>
ScatterTensorDesc is a composite attribute defined for TensorDescType to describe the following properties of a TensorDesc:
- memory_space: describes where the data block described by the TensorDesc is located, Global device memory or Shared local memory. It defaults to Global.
- chunk_size: Specifies the number of contiguous elements accessed per offset. The default value is 1.
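To visualize chunk_size, the elements reached from each offset can be enumerated. The Python sketch below assumes offsets count elements from the base address; `accessed_elements` is a hypothetical helper name.

```python
def accessed_elements(offsets, chunk_size=1):
    """List the element indices touched for each offset.

    Assumption: offsets are expressed in elements relative to the base.
    """
    return [[off + j for j in range(chunk_size)] for off in offsets]


touched = accessed_elements([0, 16, 32, 64], chunk_size=2)
```

Four offsets with chunk_size = 2 give a 4x2 access pattern.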
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
memory_space | MemorySpaceAttr | Data memory location |
chunk_size | IntegerAttr | Number of contiguous elements |
SliceAttr ¶
Describes the data distribution and sharing among subgroups or work-items.
Syntax:
#xegpu.slice<
xegpu::DistributeLayoutAttr, # parent
DenseI64ArrayAttr # dims
>
Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items. However, whereas LayoutAttr requires the data to have the same rank as the attribute, SliceAttr permits the data to have a lower rank. In this case, compute units in the specified dimensions (given by $dims) share the data, provided that the remaining ranks match the data rank. SliceAttr is commonly used by operations such as vector.multi_reduction and vector.broadcast.
Example:
#l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
#r = #xegpu.slice<#l, dims = [0]>
%exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
%red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
%bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
In this example, %red is conceptually divided into 4 vectors of type vector<32xf32>, each assigned to a group of subgroups. Each group consists of 8 subgroups from the same column of sg_layout, sharing a single reduction result of type vector<32xf32>.
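The numbers in this example can be reproduced by dropping the sliced dimension from the parent layout. The Python sketch below is an illustration only; `slice_layout` and the dictionary keys it returns are hypothetical names.

```python
def slice_layout(sg_layout, sg_data, dims):
    """Drop the sliced dimensions; subgroups along them share the same data."""
    kept = [d for d in range(len(sg_layout)) if d not in dims]
    sharing = 1
    for d in dims:
        sharing *= sg_layout[d]
    return {
        "sg_layout": [sg_layout[d] for d in kept],
        "sg_data": [sg_data[d] for d in kept],
        "subgroups_sharing_each_piece": sharing,
    }


sliced = slice_layout([8, 4], [32, 32], dims=[0])
```

For the layout above, this yields 4 pieces of 32 elements, each shared by 8 subgroups.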
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
parent | xegpu::DistributeLayoutAttr | |
dims | DenseI64ArrayAttr | |
Types ¶
MemDescType ¶
MemDesc describing the data in SLM
MemDesc represents a block of data stored in shared local memory. By default, unless a layout attribute is provided, the data is stored contiguously in row-major order within the region.
Examples:
// A multi-dimensional array stored in column-major order.
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128]>>
// A multi-dimensional array stored in a blocked layout. Elements within the same block
// are stored contiguously in memory. Blocks are stored in row-major order.
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<block = [8, 8]>>
// A multi-dimensional array stored in column-major order with blocked layout.
!xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128], block = [8, 8]>>
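One way to read the stride and block layouts is as a mapping from a 2-D index to a linear element offset. The Python sketch below is a simplified model, not the dialect's address computation: the helper `linear_offset` is hypothetical, elements inside a block are assumed to be row-major, and the combined stride plus block case is deliberately left out.

```python
def linear_offset(index, shape, stride=None, block=None):
    """Element offset of a 2-D index within the SLM region.

    Assumptions of this sketch: without `block`, the offset is the dot
    product of index and stride (row-major strides by default). With
    `block` and no `stride`, blocks are laid out in row-major order and
    elements inside a block are row-major as well.
    """
    rows, cols = shape
    i, j = index
    if block is None:
        s0, s1 = stride if stride is not None else (cols, 1)
        return i * s0 + j * s1
    if stride is not None:
        raise NotImplementedError("combined stride + block is not modeled here")
    b0, b1 = block
    blocks_per_row = cols // b1
    block_id = (i // b0) * blocks_per_row + (j // b1)
    within = (i % b0) * b1 + (j % b1)
    return block_id * (b0 * b1) + within


row_major = linear_offset((2, 3), (128, 128))
col_major = linear_offset((2, 3), (128, 128), stride=(1, 128))
blocked = linear_offset((9, 10), (128, 128), block=(8, 8))
```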
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
shape | ::llvm::ArrayRef<int64_t> | |
elementType | mlir::Type | |
mem_layout | MemLayoutAttr | |
NbarrierType ¶
!xegpu.nbarrier a custom XeGPU type representing a barrier.
Syntax: !xegpu.nbarrier
TensorDescType ¶
TensorDesc describing regions of interested data.
TensorDesc is a type designed to describe regions of interest in data, as well as some features unique to Intel hardware. Unlike the built-in tensor type in MLIR, it essentially contains only metadata and does not hold the data itself. It is primarily designed to support 2D block load/store and DPAS (matrix multiplication instruction) on Intel GPUs. It encodes the following information:
shape: the sizes/shape of the data block of interest, e.g., 8x16 means 8 rows, each containing 16 contiguous data elements. The rows may or may not be contiguous, depending on the encoding attribute. If the encoding is a BlockTensorDescAttr, rows are contiguous. If the encoding is a ScatterTensorDescAttr, rows are not necessarily contiguous. If the encoding is not set, it is treated as a default BlockTensorDescAttr.
element_type: the data type of the data element, e.g., f16, f32.
Similar to the built-in tensor, it also provides optional attributes for encoding additional information via either BlockTensorDescAttr or ScatterTensorDescAttr, or for supporting workgroup, subgroup, and work-item (or SIMT) level programming via the Layout attribute. Please check their definitions for details.
Syntax:
TensorDesc-type ::= `tensor_desc` `<` dim-list element-type (attr-list)? `>`
element-type ::= float-type | integer-type | index-type
dim-list := (static-dim-list `x`)?
static-dim-list ::= decimal-literal `x` decimal-literal
attr-list = (, encoding-attr)? (, layout-attr)?
encoding-attr = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
layout-attr = (, layout `<`sg_layout = value, sg_data = value, inst_data = value, lane_layout = value, lane_data = value, order = value`>`)?
Examples:
// A block TensorDesc with 8x16 i32 elements
xegpu.tensor_desc<8x16xi32>
// A block TensorDesc with 8x16 f32 elements
xegpu.tensor_desc<8x16xf32>
// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_space = slm>>
// A 1D TensorDesc with a layout for subgroup level programming, each lane accesses two contiguous elements
xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [2]>>
// A 1D TensorDesc with a layout for subgroup level programming, each lane accesses two elements with stride = 16
xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
// A TensorDesc with a layout for subgroup level programming
xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// A TensorDesc with a layout for workgroup level programming
xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
// A TensorDesc with a layout for workgroup level programming without lane_layout and lane_data
xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16]>>
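The difference between the two 1D layouts above (lane_data = [2] versus lane_data = [1]) can be made concrete by listing which elements each lane owns. The Python sketch below assumes distribution units are repeated round-robin until the tensor is covered; `lane_elements` is a hypothetical helper name.

```python
def lane_elements(num_elements, lane_count, lane_data):
    """Element indices owned by each lane for a 1-D tensor.

    One distribution unit covers lane_count * lane_data elements; units
    are repeated round-robin until the tensor is covered.
    """
    unit = lane_count * lane_data
    if num_elements % unit != 0:
        raise ValueError("tensor size must be a multiple of the distribution unit")
    owned = [[] for _ in range(lane_count)]
    for base in range(0, num_elements, unit):
        for lane in range(lane_count):
            start = base + lane * lane_data
            owned[lane].extend(range(start, start + lane_data))
    return owned


contiguous = lane_elements(32, lane_count=16, lane_data=2)  # lane_data = [2]
strided = lane_elements(32, lane_count=16, lane_data=1)     # lane_data = [1]
```

With lane_data = [2], lane 3 owns elements 6 and 7; with lane_data = [1], it owns elements 3 and 19, which are 16 apart.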
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
shape | ::llvm::ArrayRef<int64_t> | |
elementType | mlir::Type | |
encoding | mlir::Attribute | |
layout | mlir::Attribute | |
Enums ¶
CmpFPredicate ¶
Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
Cases: ¶
Symbol | Value | String |
---|---|---|
AlwaysFalse | 0 | false |
OEQ | 1 | oeq |
OGT | 2 | ogt |
OGE | 3 | oge |
OLT | 4 | olt |
OLE | 5 | ole |
ONE | 6 | one |
ORD | 7 | ord |
UEQ | 8 | ueq |
UGT | 9 | ugt |
UGE | 10 | uge |
ULT | 11 | ult |
ULE | 12 | ule |
UNE | 13 | une |
UNO | 14 | uno |
AlwaysTrue | 15 | true |
CmpIPredicate ¶
Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
Cases: ¶
Symbol | Value | String |
---|---|---|
eq | 0 | eq |
ne | 1 | ne |
slt | 2 | slt |
sle | 3 | sle |
sgt | 4 | sgt |
sge | 5 | sge |
ult | 6 | ult |
ule | 7 | ule |
ugt | 8 | ugt |
uge | 9 | uge |
IntegerOverflowFlags ¶
Integer overflow arith flags
Cases: ¶
Symbol | Value | String |
---|---|---|
none | 0 | none |
nsw | 1 | nsw |
nuw | 2 | nuw |
RoundingMode ¶
Floating point rounding mode
Cases: ¶
Symbol | Value | String |
---|---|---|
to_nearest_even | 0 | to_nearest_even |
downward | 1 | downward |
upward | 2 | upward |
toward_zero | 3 | toward_zero |
to_nearest_away | 4 | to_nearest_away |
AtomicRMWKind ¶
Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
Cases: ¶
Symbol | Value | String |
---|---|---|
addf | 0 | addf |
addi | 1 | addi |
andi | 2 | andi |
assign | 3 | assign |
maximumf | 4 | maximumf |
maxnumf | 5 | maxnumf |
maxs | 6 | maxs |
maxu | 7 | maxu |
minimumf | 8 | minimumf |
minnumf | 9 | minnumf |
mins | 10 | mins |
minu | 11 | minu |
mulf | 12 | mulf |
muli | 13 | muli |
ori | 14 | ori |
xori | 15 | xori |
FastMathFlags ¶
Floating point fast math flags
Cases: ¶
Symbol | Value | String |
---|---|---|
none | 0 | none |
reassoc | 1 | reassoc |
nnan | 2 | nnan |
ninf | 4 | ninf |
nsz | 8 | nsz |
arcp | 16 | arcp |
contract | 32 | contract |
afn | 64 | afn |
fast | 127 | fast |
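Every individual flag value in the table is a distinct power of two, and `fast` (127) equals the sum of all of them, so the values can be combined and decoded as a bit mask. The Python sketch below illustrates that arithmetic only; the names `FLAGS`, `combine`, and `decode` are hypothetical and do not correspond to any MLIR API.

```python
# Flag values copied from the FastMathFlags table above.
FLAGS = {
    "reassoc": 1,
    "nnan": 2,
    "ninf": 4,
    "nsz": 8,
    "arcp": 16,
    "contract": 32,
    "afn": 64,
}
FAST = 127  # value listed for `fast` in the table


def combine(*names):
    """OR together the values of the named flags."""
    mask = 0
    for name in names:
        mask |= FLAGS[name]
    return mask


def decode(mask):
    """List the names of the individual flags set in a mask."""
    return [name for name, bit in FLAGS.items() if mask & bit]


# `fast` is the union of all seven individual flags.
all_flags = combine(*FLAGS)
# Combining nnan (2) and ninf (4) yields 6, which decodes back to both names.
no_nan_no_inf = combine("nnan", "ninf")
```

A mask of 0 corresponds to `none`, for which `decode` returns an empty list.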
CachePolicy ¶
Cache policy
Cases: ¶
Symbol | Value | String |
---|---|---|
CACHED | 0 | cached |
UNCACHED | 1 | uncached |
STREAMING | 2 | streaming |
READ_INVALIDATE | 3 | read_invalidate |
WRITE_BACK | 4 | write_back |
WRITE_THROUGH | 5 | write_through |
FenceScope ¶
The enumeration for the scope of a fence operation.
Cases: ¶
Symbol | Value | String |
---|---|---|
Workgroup | 0 | workgroup |
GPU | 1 | gpu |
MemorySpace ¶
The address space of the memory that the tensor descriptor is created for
Cases: ¶
Symbol | Value | String |
---|---|---|
Global | 0 | global |
SLM | 3 | slm |