'xegpu' Dialect
The XeGPU dialect that models Intel GPU’s ISA
The XeGPU dialect closely models a subset of the Xe GPU’s ISA, providing an abstraction to support high-performance GEMM code generation. It serves as a bridge dialect in the MLIR gradual lowering process, working with MLIR memref and vector types, and complements the Arith, Math, Vector, and Memref dialects. XeGPU operations are introduced for special Xe instructions not modeled by the LLVM/SPIR-V dialect, such as DPAS and 2D block load and store.
It supports a tile-based programming model, decomposing the GEMM kernel into large predefined tile sizes at the subgroup and workgroup levels. XeGPU allows the high-level GEMM algorithm to be easily expressed. Underneath, it uses target-specific recipes and hardware features to achieve optimal performance on specific hardware. By decomposing GEMM at submatrix granularity and mapping it to registers, it naturally supports optimizations like fusing with neighboring operations.
Operations
xegpu.alloc_nbarrier (xegpu::AllocNbarrierOp)
It allocates a set of named barriers.
Syntax:
operation ::= `xegpu.alloc_nbarrier` $nbarrier_num attr-dict
AllocNbarrier creates a set of named barriers as specified by nbarrier_num. Named barriers are workgroup-level resources and are shared by all threads in the workgroup. For example, there are up to 32 barriers (range 0-31) for each XeCore on PVC. A typical use case is that a workgroup is partitioned into N subgroups of threads (N <= 32), and each subgroup coordinates its work with a separate barrier, with IDs ranging from 0 to N-1.
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
nbarrier_num | ::mlir::IntegerAttr | 64-bit signless integer attribute |
xegpu.atomic_rmw (xegpu::AtomicRMWOp)
Atomic read-modify-write operation on the TensorDesc.
Syntax:
operation ::= `xegpu.atomic_rmw` $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)
The xegpu.atomic_rmw operation provides a way to perform a read-modify-write operation on the region described by the TensorDesc, free from data races. The kind enumeration specifies the modification to be performed. The mask operand has the same shape as the TensorDesc and is used to enable or disable specific data points of the TensorDesc. The value operand represents the new value to be applied during the modification.
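To make the roles of kind, mask, and value concrete, here is a minimal sequential Python model of the semantics. It is a sketch under stated assumptions, not the hardware behavior: the TensorDesc is modeled as a plain list of element indices into a flat memory list, only three kinds are modeled (named after arith's AtomicRMWKind cases), and returning the pre-update values is an assumption of this sketch rather than something the description above states.

```python
# Minimal sequential model of masked atomic read-modify-write (a sketch, not the
# hardware). Assumptions: the TensorDesc is a list of element indices into a flat
# memory list, and the op returns the values read before the update.
from typing import Callable, Dict, List

# Only three kinds are modeled; the names follow arith's AtomicRMWKind cases.
KINDS: Dict[str, Callable[[float, float], float]] = {
    "addf": lambda old, new: old + new,
    "maximumf": lambda old, new: max(old, new),
    "assign": lambda old, new: new,
}

def atomic_rmw(kind: str, memory: List[float], indices: List[int],
               mask: List[bool], values: List[float]) -> List[float]:
    """Apply `kind` at every enabled data point and return the pre-update values."""
    assert len(indices) == len(mask) == len(values), "mask/value must match the TensorDesc"
    combine = KINDS[kind]
    previous = []
    for idx, enabled, val in zip(indices, mask, values):
        previous.append(memory[idx])  # in this sketch, disabled points still report memory
        if enabled:  # disabled data points leave memory untouched
            memory[idx] = combine(memory[idx], val)
    return previous

mem = [1.0, 2.0, 3.0, 4.0]
old = atomic_rmw("addf", mem, indices=[0, 2, 3], mask=[True, False, True],
                 values=[10.0, 10.0, 10.0])
print(old, mem)  # [1.0, 3.0, 4.0] [11.0, 2.0, 3.0, 14.0]
```

Because the model is sequential, freedom from data races is trivially satisfied; the real operation provides that guarantee across concurrently executing work-items.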
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, MemoryEffectOpInterface (MemoryEffectOpInterface), NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource, MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}, MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
kind | ::mlir::arith::AtomicRMWKindAttr | allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 |
Operands:
Operand | Description |
---|---|
tensorDesc | TensorDesc describing regions of data of interest. |
mask | vector of 1-bit signless integer values of ranks 1 or 1-bit signless integer |
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
Results:
Result | Description |
---|---|
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
xegpu.convert_layout (xegpu::ConvertLayoutOp)
Convert the layout of the input operand
Syntax:
operation ::= `xegpu.convert_layout` $source attr-dict `:` type($source)
convert_layout adjusts the data distribution across subgroups and/or work-items by modifying the LayoutAttr. Both srcMap and resMap must correspond to the same programming scope, such as workgroup-level (wg) or subgroup-level (sg) code. This operation is not valid once the IR is lowered to the work-item (WI) level, because that level is the end result of all distributions.
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
srcMap | ::mlir::xegpu::LayoutAttr | Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor. |
resMap | ::mlir::xegpu::LayoutAttr | Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor. |
Operands:
Operand | Description |
---|---|
source | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2 |
Results:
Result | Description |
---|---|
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2 |
xegpu.create_nd_tdesc (xegpu::CreateNdDescOp)
Create nd-tensor descriptor operation
Syntax:
operation ::= `xegpu.create_nd_tdesc` $source ``
custom<DynamicIndexList>($offsets, $const_offsets)
(`,` custom<DynamicIndexList>($shape, $const_shape)^
`,` custom<DynamicIndexList>($strides, $const_strides))?
attr-dict `:` type($source) `->` qualified(type($TensorDesc))
The “create_nd_tdesc” operation creates a TensorDescType, which represents a sub-view of a 1D/2D memory region inside the one or two innermost dimensions of the source. (It can be extended to support n-D memory regions in the future if needed.) Elements in the subview are contiguous in each dimension. It encodes the following important information for supporting Intel hardware features:
- source: an object representing (the starting address/pointer of) a memory region. It can be either a memref object or simply a pointer represented by the uint64_t type. For dynamic memrefs or pointers, the shape and layout information of the memory region should be passed explicitly via the shape and strides parameters.
- offsets: index values that represent offsets from the “source” in each dimension at which the subview of the target memory is created. It is encoded via “offsets” and “const_offsets”, so that it can accept various forms, such as operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
- shape: the shape information of the memory region pointed to by the “source”. It is typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. But if “source” is simply a pointer represented as uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the shape information has to be passed explicitly via the “shape” and “const_shape” arguments.
- strides: the strides of the memory region pointed to by the “source”. Similar to shape, it is typically encoded via the MemRefType of the source. But if “source” is simply a pointer represented as uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the strides information has to be passed explicitly via the “strides” and “const_strides” arguments.
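The interplay of offsets, shape, and strides boils down to plain address arithmetic. The hypothetical Python helper below (not part of any XeGPU API) lists the linear element offsets covered by a 2D subview, assuming strides are counted in elements, as with the [%w, %c1] strides of a row-major memref.

```python
# Hypothetical helper illustrating create_nd_tdesc addressing: the linear element
# offset of tile element (i, j) is (off0 + i) * stride0 + (off1 + j) * stride1.
# Strides are assumed to be in elements, not bytes.
from typing import List, Sequence

def subview_offsets(offsets: Sequence[int], strides: Sequence[int],
                    tile_shape: Sequence[int]) -> List[List[int]]:
    """Linear element offsets covered by a 2D subview starting at `offsets`."""
    (r0, c0), (s0, s1), (rows, cols) = offsets, strides, tile_shape
    return [[(r0 + i) * s0 + (c0 + j) * s1 for j in range(cols)]
            for i in range(rows)]

# Row-major memref<1024x1024xf32>: shape [1024, 1024], strides [1024, 1].
tile = subview_offsets(offsets=[2, 4], strides=[1024, 1], tile_shape=[8, 16])
print(tile[0][:3])  # [2052, 2053, 2054]
```

Each row of the result is a run of consecutive offsets, which is the "contiguous in each dimension" property a 2D block load relies on.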
In SIMT mode, the tensor descriptor is augmented with a LayoutAttr, which describes the mapping of the tensor descriptor to the work items.
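The per-work-item vector shapes in this page's SIMT examples (for instance vector<8x1xf32> for an 8x16 descriptor with lane_layout = [1, 16], lane_data = [1, 1]) can be reproduced with a small sketch. It assumes that each lane owns lane_data-sized blocks distributed round-robin over a row-major lane_layout grid; the helper names are hypothetical, and this is an illustration rather than the dialect's distribution code.

```python
# Sketch of LayoutAttr-style distribution (assumed round-robin, row-major lanes).
from typing import Sequence, Tuple

def lane_fragment_shape(tile: Sequence[int], lane_layout: Sequence[int],
                        lane_data: Sequence[int]) -> Tuple[int, ...]:
    """Shape of the fragment that one work-item holds."""
    shape = []
    for size, lanes, data in zip(tile, lane_layout, lane_data):
        step = lanes * data  # elements covered by all lanes in one round
        assert size % step == 0, "tile must be a multiple of lane_layout * lane_data"
        shape.append((size // step) * data)
    return tuple(shape)

def lane_of(index: Sequence[int], lane_layout: Sequence[int],
            lane_data: Sequence[int]) -> int:
    """Linear lane id that owns 2D element `index`, under the same assumption."""
    r = (index[0] // lane_data[0]) % lane_layout[0]
    c = (index[1] // lane_data[1]) % lane_layout[1]
    return r * lane_layout[1] + c

print(lane_fragment_shape((8, 16), (1, 16), (1, 1)))  # (8, 1)
```

With lane_layout = [1, 16], lane j owns column j of the 8x16 tile, so each work-item holds 8 elements.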
Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
%0 = memref.alloc() : memref<1024x1024xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
%0 = memref.alloc(%h, %w) : memref<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
%0 = ... : ui64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
Example 4 (SIMT mode):
%0 = memref.alloc() : memref<1024x1024xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0] : memref<1024x1024xf32>
-> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OffsetSizeAndStrideOpInterface, ViewLikeOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
const_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
const_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
source | non-0-ranked.memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | variadic of index |
shape | variadic of index |
strides | variadic of index |
Results:
Result | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
xegpu.create_tdesc (xegpu::CreateDescOp)
Create scattered tensor descriptors (TensorDesc).
Syntax:
operation ::= `xegpu.create_tdesc` $source `,` $offsets attr-dict `:` type($source) `,` type($offsets) `->` qualified(type($TensorDesc))
“create_tdesc” is similar to “create_nd_tdesc” in that it creates a Tensor Descriptor (TensorDescType) for a memory region. While “create_nd_tdesc” creates contiguous subviews, “create_tdesc” creates non-contiguous (scattered) subviews, allowing each work-item in a subgroup to specify its own offset. It accepts the following parameters:
- source: a 1D memref or pointer (uint64_t) that represents the flattened memory object.
- offsets: a vector containing the offset of each access point. Its size is fixed to the hardware-supported subgroup size, e.g., 16 on PVC, so each element in the vector corresponds to a work-item (SIMT lane) in the subgroup.
The first dimension of the result TensorDesc corresponds to work-items, so it should match the dimension of offsets. It may also have a second dimension corresponding to the chunk_size if the chunk size is larger than 1.
In SIMT mode, similar to create_nd_tdesc, the resulting tensor descriptor is augmented with a LayoutAttr, which describes the mapping of the tensor descriptor to the work items. In this case, the first dimension of the tensor descriptor represents the work-items, and the second dimension represents the chunk size.
Example 1: It assumes the subgroup size is 4 and accesses a[0], a[16], a[32], a[64].
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %a, %0 : memref<1024xf32>, vector<4xindex> -> !xegpu.tensor_desc<4xf32, #xegpu.scatter_tdesc_attr<>>
Example 2: It assumes the subgroup size is 4 and each work-item accesses 8 elements, for a total of 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71].
%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
-> !xegpu.tensor_desc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
Example 3: It is similar to Example 2, but there is some overlap among work-items. It accesses a[0:7], a[4:11], a[8:15], a[12:19].
%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
-> !xegpu.tensor_desc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
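The access patterns listed in Examples 1-3 follow directly from offsets and chunk_size. The small Python sketch below (the helper name is hypothetical) enumerates them: row i holds the flattened-memory indices touched by work-item i, so the result has the same [work-item][chunk] shape as the TensorDesc.

```python
# Element indices touched by a scattered TensorDesc: work-item i accesses
# chunk_size consecutive elements starting at offsets[i].
from typing import List, Sequence

def scattered_indices(offsets: Sequence[int], chunk_size: int = 1) -> List[List[int]]:
    """Row i lists the flattened-memory indices accessed by work-item i."""
    return [[off + c for c in range(chunk_size)] for off in offsets]

ex1 = scattered_indices([0, 16, 32, 64])               # a[0], a[16], a[32], a[64]
ex2 = scattered_indices([0, 16, 32, 64], chunk_size=8)  # a[0:7], a[16:23], ...
ex3 = scattered_indices([0, 4, 8, 12], chunk_size=8)    # overlapping chunks
print(ex3[0], ex3[1])  # [0, 1, 2, 3, 4, 5, 6, 7] [4, 5, 6, 7, 8, 9, 10, 11]
```

In Example 3 the last four indices of work-item 0 are the first four of work-item 1, which is the overlap the example describes.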
Example 4: SIMT mode
%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
-> !xegpu.tensor_desc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>,
#xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
source | non-0-ranked.memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | vector of index values of ranks 1 |
Results:
Result | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
xegpu.dpas (xegpu::DpasOp)
It performs MMA (matrix multiply-accumulate) computation.
Syntax:
operation ::= `xegpu.dpas` $lhs `,` $rhs (`,` $acc^)? attr-dict `:` type($lhs)`,` type($rhs) (`,` type($acc)^)? `->` type($result)
DPAS performs matrix multiplication on matrix A of size m x k and matrix B of size k x n, and accumulates on matrix C of size m x n, producing a result matrix D of the same size, where m = 8, n = 16, and k = 8 * 32 / bit_width_of_elem_type. So for the fp16 data type, the matrices are A: vector<8x16xf16>, B: vector<16x16xf16>, and C/D: vector<8x16xf32>. Besides the matrix size requirements, DPAS also requires A and B to be loaded with the required data layout. Specifically, VNNI layout is required for the B operand. It is achieved by adding the packed attribute to the load_nd operator. Due to the VNNI transformation, the B operand can be represented as a 3D vector, with the last dimension representing the VNNI factor, which is computed as 32 / bit_width_of_elem_type. Thus, B: vector<16x16xf16> can be represented as B: vector<8x16x2xf16>.
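The size rules above can be checked with a few lines of arithmetic. This Python sketch (the function name is hypothetical) derives the A, B, packed-B, and C/D shapes from the element bit width using only the formulas stated in the paragraph; it does not check which element types the hardware actually supports.

```python
# DPAS operand shapes from the element bit width, per the formulas above:
# m = 8, n = 16, k = 8 * 32 / bit_width, VNNI factor = 32 / bit_width.
from typing import Dict, Tuple

def dpas_shapes(bit_width: int) -> Dict[str, Tuple[int, ...]]:
    assert 32 % bit_width == 0, "bit width must divide 32"
    m, n = 8, 16
    vnni = 32 // bit_width  # elements packed into one 32-bit slot
    k = 8 * vnni
    return {
        "A": (m, k),
        "B": (k, n),
        "B_vnni": (k // vnni, n, vnni),  # B after the packed (VNNI) load
        "C/D": (m, n),
    }

print(dpas_shapes(16))  # {'A': (8, 16), 'B': (16, 16), 'B_vnni': (8, 16, 2), 'C/D': (8, 16)}
```

For fp16 this reproduces A: vector<8x16xf16>, B: vector<16x16xf16> (or vector<8x16x2xf16> once packed), and an 8x16 accumulator.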
In SIMT code, each work-item from a subgroup holds a data fragment for A, B, C and the result, which are represented as 1D vectors. Please refer to [OpenCL Intel extensions](https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html) for more details about the fragment distribution.
Note: on PVC, the hardware can perform a load with VNNI transformation when the data element type is 16-bit or lower precision, taking 2 or 4 elements from the first dimension and inserting them into the newly added innermost dimension.
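The VNNI transformation in the note can be written as an index mapping. The Python sketch below is a software model under the reading that f = 32 / bit_width consecutive rows are folded into the new innermost dimension, i.e. packed[i][j][v] == B[i * f + v][j]; it illustrates the data movement, not how the hardware performs it.

```python
# VNNI packing of a k x n matrix into (k / f) x n x f, taking f consecutive
# elements of the first (row) dimension into the new innermost dimension.
from typing import List

def vnni_pack(b: List[List[int]], factor: int) -> List[List[List[int]]]:
    k, n = len(b), len(b[0])
    assert k % factor == 0, "row count must be a multiple of the VNNI factor"
    return [[[b[i * factor + v][j] for v in range(factor)] for j in range(n)]
            for i in range(k // factor)]

# 4x3 example with factor 2 (16-bit data): rows 0 and 1 interleave into packed row 0.
b = [[r * 10 + c for c in range(3)] for r in range(4)]
packed = vnni_pack(b, factor=2)
print(packed[0])  # [[0, 10], [1, 11], [2, 12]]
```

Applied to a 16x16 fp16 B tile with factor 2, the same mapping yields the 8x16x2 shape used by DPAS.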
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
lhs | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3 |
rhs | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3 |
acc | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2 |
Results:
Result | Description |
---|---|
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2 |
xegpu.fence (xegpu::FenceOp)
It synchronizes memory accesses.
Syntax:
operation ::= `xegpu.fence` `memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` `` $fence_scope attr-dict
It synchronizes memory accesses between a write and the following read or write.
1. memory_kind describes the memory kind: “global” means global memory, and “slm” means shared local memory.
2. fence_scope describes the scope of the fence: “workgroup” means that the scope is within each workgroup, and “gpu” means that the scope is across workgroups within the GPU.
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
memory_kind | ::mlir::xegpu::MemorySpaceAttr | Describe the location of data described by a `TensorDesc`: Global device memory (`Global`) or Shared local memory (`SLM`). |
fence_scope | ::mlir::xegpu::FenceScopeAttr | Describes the scope of fence. "workgroup" means that the scope is within each work group. "gpu" means the scope is across work groups within the gpu. |
xegpu.init_nbarrier (xegpu::InitNbarrierOp)
It assigns a named barrier to the current thread.
Syntax:
operation ::= `xegpu.init_nbarrier` $nbarrier_id `,` $participant_thread_num attr-dict `:`
type($nbarrier_id) `,` type($participant_thread_num) `->` qualified(type($result))
InitNbarrierOp assigns the named barrier with the specified barrier ID (0-31) to the current thread. Multiple threads may bind to the same named barrier, and participant_thread_num specifies the total number of threads associated with the nbarrier. It returns an object of NbarrierType representing the barrier.
Operands:
Operand | Description |
---|---|
nbarrier_id | 8-bit signless integer |
participant_thread_num | 8-bit signless integer |
Results:
Result | Description |
---|---|
result | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.load (xegpu::LoadGatherOp)
Load a set of scattered data points from memory.
Syntax:
operation ::= `xegpu.load` $TensorDesc `,` $mask prop-dict attr-dict
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)
LoadGatherOp (aka. load) loads data per work-item. The output describes the data being loaded at the subgroup level, so its size is consistent with the number of work-items in a subgroup. When the chunk size is larger than 1, the output vector is a 2D vector, with dim-1 corresponding to work-items and dim-0 corresponding to the chunk size loaded by each work-item. Note that there is a transpose effect on the result (compared to the TensorDesc) due to the hardware implementation. Therefore, a transpose attribute is introduced on purpose, making sure users are aware of this implicit transformation.
The mask operand masks out memory accesses, so it is safe to pass out-of-bounds addresses/offsets as long as they are masked. The mask applies per SIMD lane.
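The masking and the implicit transpose can be modeled in a few lines of Python. This is a hypothetical sketch: memory is a flat list, masked-out lanes yield a `fill` value (the value the hardware produces for such lanes is not specified here), and with chunk_size > 1 the result is indexed [chunk][lane], the transposed form of the [lane][chunk] TensorDesc.

```python
# Hypothetical model of xegpu.load (gather) at subgroup level.
from typing import List, Sequence, Union

def load_gather(memory: Sequence[float], offsets: Sequence[int],
                mask: Sequence[bool], chunk_size: int = 1,
                fill: float = 0.0) -> Union[List[float], List[List[float]]]:
    assert len(mask) == len(offsets), "one mask bit per work-item"
    # Masked-out lanes never index memory, so their offsets may be out of bounds.
    per_lane = [[memory[off + c] if on else fill for c in range(chunk_size)]
                for off, on in zip(offsets, mask)]
    if chunk_size == 1:
        return [chunk[0] for chunk in per_lane]  # 1D: one element per lane
    # Transpose effect: dim-0 is the chunk, dim-1 is the work-item.
    return [[per_lane[lane][c] for lane in range(len(offsets))]
            for c in range(chunk_size)]

mem = [float(i) for i in range(64)]
flat = load_gather(mem, [0, 8, 1000], [True, True, False])  # lane 2 is out of bounds but masked
tiled = load_gather(mem, [0, 16], [True, True], chunk_size=3)
print(tiled)  # [[0.0, 16.0], [1.0, 17.0], [2.0, 18.0]]
```

A 16x8 descriptor with chunk_size = 8 therefore produces an 8x16 result, matching the shapes in Example 2 below.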
In SIMT mode, LoadGatherOp expects the tensor descriptor to be augmented with a LayoutAttr, which describes the mapping of the tensor to the work items. In this case, the result vector represents the data to be loaded by each work-item. Each work-item receives chunk_size elements.
Example 1:
%2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}
: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
vector<16xi1> -> vector<16xf32>
Example 2:
%2 = xegpu.load %1, %0 {transpose,
l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<8x16xf32>
Example 3 (SIMT mode):
%2 = xegpu.load %1, %0 {transpose,
l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<uncached>}
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>,
vector<16xi1> -> vector<8x1xf32>
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
transpose | ::mlir::UnitAttr | unit attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
mask | vector of 1-bit signless integer values of ranks 1 or 1-bit signless integer |
Results:
Result | Description |
---|---|
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
xegpu.load_nd (xegpu::LoadNdOp)
Loads an n-D block from memory (represented by TensorDesc) to registers (represented by vector).
Syntax:
operation ::= `xegpu.load_nd` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
LoadNdOp essentially mimics the hardware block read instruction to read a block of data from memory to registers. It takes a set of optional cache hints for each level of cache: L1, L2, and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute is masked. VNNI transformation is a hardware feature of Intel GPUs that packs data during the load of the B operand of a matrix operation when the bit width of the data type is less than 32 bits, e.g., fp16. Transpose is another Intel hardware feature, which transposes the data while loading when the data type is 32 or 64 bits wide, e.g., fp32 or fp64. This implies that VNNI (packed) and transpose cannot be used at the same time.
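The constraints in the paragraph above can be summarized as a shape rule. The Python sketch below is hypothetical: it treats transpose as a simple 2D swap (the attribute itself is an i64 array such as [1, 0]), applies VNNI with factor 32 / bit_width, and rejects the combinations the text rules out. It models subgroup-level shapes only, not data movement.

```python
# Hypothetical result-shape rule for xegpu.load_nd at subgroup level.
from typing import Sequence, Tuple

def load_nd_result_shape(tdesc_shape: Sequence[int], bit_width: int,
                         packed: bool = False, transpose: bool = False) -> Tuple[int, ...]:
    if packed and transpose:
        raise ValueError("packed (VNNI) and transpose cannot be combined")
    rows, cols = tdesc_shape
    if packed:
        if bit_width >= 32:
            raise ValueError("VNNI packing needs an element type narrower than 32 bits")
        f = 32 // bit_width  # VNNI factor
        return (rows // f, cols, f)
    if transpose:
        if bit_width not in (32, 64):
            raise ValueError("transpose needs a 32- or 64-bit element type")
        return (cols, rows)
    return (rows, cols)

print(load_nd_result_shape([8, 16], 32, transpose=True))  # (16, 8)
```

The transposed fp32 case reproduces Example 1 below (tensor_desc<8x16xf32> loaded as vector<16x8xf32>), and packing a 16x16 fp16 tile gives the 8x16x2 B layout that DPAS expects.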
In SIMT mode, LoadNdOp expects the tensor descriptor to be augmented with a LayoutAttr, which describes the mapping of the tensor to the work items. In this case, the result vector represents the data to be loaded by each work-item.
Example 1:
xegpu.load_nd %1 {transpose = [1, 0],
l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>,
l3_hint = #xegpu.cache_hint<streaming>}
: !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
Example 2 (SIMT mode):
xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<uncached>}
: !xegpu.tensor_desc<8x16xf32,
#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x1xf32>
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
packed | ::mlir::UnitAttr | unit attribute |
transpose | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
Results:
Result | Description |
---|---|
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
xegpu.nbarrier_arrive (xegpu::NbarrierArriveOp)
It signals the arrival at the named barrier.
Syntax:
operation ::= `xegpu.nbarrier_arrive` $nbarrier attr-dict `:` qualified(type($nbarrier))
NbarrierArriveOp signals the hardware (or other threads) that the current thread has produced its data for the consumer threads. When the hardware has been signaled by participant_thread_num threads for the named barrier, it notifies the threads waiting on the named barrier to continue their work.
Operands:
Operand | Description |
---|---|
nbarrier | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.nbarrier_wait (xegpu::NbarrierWaitOp)
It waits for a named barrier.
Syntax:
operation ::= `xegpu.nbarrier_wait` $nbarrier attr-dict `:` qualified(type($nbarrier))
NbarrierWaitOp signals the hardware which named barrier the current thread is waiting for, so that the thread is notified when the named barrier completes.
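The split between signaling arrival and waiting can be illustrated with a software model built on Python's threading.Condition. This is an assumption-laden sketch, not the hardware mechanism: the hypothetical NamedBarrier class stands in for the !xegpu.nbarrier value returned by init_nbarrier, arrive() plays the role of nbarrier_arrive, and wait() plays the role of nbarrier_wait.

```python
# Software model of named-barrier arrive/wait (an assumption, not the hardware).
import threading

class NamedBarrier:
    def __init__(self, participant_thread_num: int) -> None:
        self.participants = participant_thread_num
        self.arrived = 0                 # arrivals in the current phase
        self.phase = 0                   # incremented each time the barrier completes
        self.cond = threading.Condition()
        self.local = threading.local()   # remembers each thread's arrival phase

    def arrive(self) -> None:
        """Like nbarrier_arrive: signal that this thread's data is ready."""
        with self.cond:
            self.local.phase = self.phase
            self.arrived += 1
            if self.arrived == self.participants:  # last arrival completes the barrier
                self.arrived = 0
                self.phase += 1
                self.cond.notify_all()

    def wait(self) -> None:
        """Like nbarrier_wait: block until the barrier this thread arrived at completes."""
        with self.cond:
            arrived_phase = self.local.phase
            self.cond.wait_for(lambda: self.phase > arrived_phase)

# Usage: 4 threads produce, then consume only after all of them have arrived.
bar = NamedBarrier(participant_thread_num=4)
log = []
log_lock = threading.Lock()

def worker(tid: int) -> None:
    with log_lock:
        log.append(("produced", tid))
    bar.arrive()
    bar.wait()
    with log_lock:
        log.append(("consumed", tid))

threads = [threading.Thread(target=worker, args=(t,)) for t in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Keeping arrive and wait separate lets a thread do unrelated work between signaling its data and blocking on others, which is the point of the two-step design.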
Operands:
Operand | Description |
---|---|
nbarrier | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.prefetch (xegpu::PrefetchOp)
Prefetches a set of scattered data points to cache
Syntax:
operation ::= `xegpu.prefetch` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))
It issues instructions to prefetch a set of scattered data points from memory to each level of the cache based on their cache policy. Unlike prefetch_nd, which works on a non-scattered TensorDesc, it works on a scattered TensorDesc.
Example:
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16, #xegpu.scatter_tdesc_attr<>>
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
xegpu.prefetch_nd (xegpu::PrefetchNdOp)
Prefetches an n-D block to cache
Syntax:
operation ::= `xegpu.prefetch_nd` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))
It issues an instruction to prefetch a block of data from a contiguous memory region to each level of the cache based on their cache policy.
Example:
xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<8x16xf16>
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
xegpu.store (xegpu::StoreScatterOp)
Store data to scattered memory locations.
Syntax:
operation ::= `xegpu.store` $value `,` $TensorDesc `,` $mask prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)
`StoreScatterOp` (`xegpu.store`) stores data to scattered memory locations. The value is typically a 1D vector, but when the chunk size of the TensorDesc is larger than 1, it is a 2D vector instead. In the latter case, dim-1 of the value corresponds to the SIMD lanes and dim-0 corresponds to the chunk size stored per lane. Like `load_gather`, `store_scatter` therefore has a transpose effect, so a `transpose` attribute is introduced on purpose to make sure users are aware of this implicit transformation.
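To make the transpose effect concrete, here is a minimal Python model of what a masked scattered store with `chunk_size > 1` writes. It is a sketch under stated assumptions, not the XeGPU lowering: `scatter_store`, a flat `memory` list indexed in elements, and one base offset per lane are hypothetical stand-ins, and the value is laid out as `[chunk_size][lanes]` as described above.

```python
def scatter_store(memory, offsets, value, mask, chunk_size):
    """Hypothetical model of xegpu.store with chunk_size > 1 (not an XeGPU API).

    offsets: one base element offset per SIMD lane.
    value:   2D list shaped [chunk_size][lanes]; value[j][lane] is element j
             of the chunk owned by `lane` (dim-0 = chunk, dim-1 = lanes).
    mask:    one bool per lane; masked-off lanes write nothing.
    """
    lanes = len(offsets)
    if len(value) != chunk_size or any(len(row) != lanes for row in value):
        raise ValueError("value must be shaped [chunk_size][lanes]")
    for lane in range(lanes):
        if not mask[lane]:
            continue
        for j in range(chunk_size):
            # Each lane writes chunk_size contiguous elements starting at its offset,
            # taking them from its column of `value`: this is the implicit transpose.
            memory[offsets[lane] + j] = value[j][lane]
    return memory


memory = scatter_store([0] * 16, [0, 4, 8, 12],
                       [[10, 11, 12, 13], [20, 21, 22, 23]],
                       [True] * 4, chunk_size=2)
```

With 4 lanes, `chunk_size = 2`, and offsets `[0, 4, 8, 12]`, the two contiguous memory elements owned by lane 0 come from column 0 of the value, which is exactly the transformation the `transpose` attribute makes explicit.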
In SIMT mode, `StoreScatterOp` expects the tensor descriptor to be augmented with a `LayoutAttr`, which describes the mapping of the tensor to the work-items. In this case, the input vector represents the data to be stored by each work-item, and each work-item stores `chunk_size` elements.
Example 1:
```mlir
xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                        l2_hint = #xegpu.cache_hint<write_back>,
                        l3_hint = #xegpu.cache_hint<write_through>}
    : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
```
Example 2:
```mlir
xegpu.store %0, %1, %2 {transpose,
                        l1_hint = #xegpu.cache_hint<uncached>,
                        l2_hint = #xegpu.cache_hint<write_back>,
                        l3_hint = #xegpu.cache_hint<write_through>}
    : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>>, vector<16xi1>
```
Example 3 (SIMT mode):
```mlir
xegpu.store %0, %1, %2 {transpose,
                        l1_hint = #xegpu.cache_hint<uncached>,
                        l2_hint = #xegpu.cache_hint<write_back>,
                        l3_hint = #xegpu.cache_hint<write_through>}
    : vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>,
      #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
```
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
transpose | ::mlir::UnitAttr | unit attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands: ¶
Operand | Description |
---|---|
value | vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, or 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2/3/4; or a scalar of any of those element types |
TensorDesc | TensorDesc describing regions of interest. |
mask | vector of 1-bit signless integer values of ranks 1 or 1-bit signless integer |
xegpu.store_nd
(xegpu::StoreNdOp) ¶
Stores an n-D block register region back to memory; currently only 2D is supported
Syntax:
operation ::= `xegpu.store_nd` $value `,` $TensorDesc prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc))
`StoreNdOp` essentially mimics the hardware block write instruction to write a block of data from registers into the memory region described by the TensorDesc. It takes a set of optional cache hints for each level of cache: L1, L2, and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute is masked.
In SIMT mode, `StoreNdOp` expects the tensor descriptor to be augmented with a `LayoutAttr`, which describes the mapping of the tensor to the work-items. In this case, the input vector represents the data to be stored by each work-item.
Example 1:
```mlir
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                       l2_hint = #xegpu.cache_hint<write_back>,
                       l3_hint = #xegpu.cache_hint<write_through>}
    : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
```
Example 2 (SIMT mode):
```mlir
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                       l2_hint = #xegpu.cache_hint<write_back>,
                       l3_hint = #xegpu.cache_hint<write_through>}
    : vector<8x1xf16>, !xegpu.tensor_desc<8x16xf16,
        #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
```
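As a rough cross-check of the SIMT shapes, the sketch below computes the per-work-item fragment shape of a 2D block from `lane_layout` and `lane_data`. It assumes each dimension divides evenly into distribution units and ignores any reshaping a real lowering may apply; `lane_fragment_shape` is a hypothetical helper, not an XeGPU API.

```python
def lane_fragment_shape(tensor_shape, lane_layout, lane_data):
    """Hypothetical helper: per-work-item fragment shape of a 2D block.

    Simplified model; assumes every dimension is a multiple of
    lane_layout[i] * lane_data[i].
    """
    shape = []
    for dim, lanes, data in zip(tensor_shape, lane_layout, lane_data):
        unit = lanes * data  # elements one distribution unit spans in this dim
        if dim % unit != 0:
            raise ValueError(f"dimension {dim} is not a multiple of {unit}")
        # Each lane owns `data` elements per unit, and units repeat dim // unit times.
        shape.append((dim // unit) * data)
    return shape
```

For Example 2 (an `8x16` block, `lane_layout = [1, 16]`, `lane_data = [1, 1]`) it gives `[8, 1]`, matching `vector<8x1xf16>`.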
Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describe the cache settings for prefetch/load/store operators |
Operands: ¶
Operand | Description |
---|---|
value | vector of 1/8/16/32/64-bit signless, signed, or unsigned integer, or 16/32/64-bit float, bfloat16, or tf32 values of ranks 1/2/3/4; or a scalar of any of those element types |
TensorDesc | TensorDesc describing regions of interest. |
xegpu.update_nd_offset
(xegpu::UpdateNdOffsetOp) ¶
It updates the offsets for the TensorDesc.
Syntax:
operation ::= `xegpu.update_nd_offset` $TensorDesc `,`
custom<DynamicIndexList>($offsets, $const_offsets)
attr-dict `:` qualified(type($result))
The op updates the offset of the given TensorDesc. The offsets are relative to the current position and are measured in number of elements. The result is a TensorDesc of the same type as the input.
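A typical use is walking a tile along one dimension inside a loop. The hypothetical helper below only models the offset arithmetic, not the op itself: because each update is relative, applying `[0, 16]` repeatedly visits tile origins 16 elements apart.

```python
def tile_origins(start, step, num_iters):
    """Hypothetical model: absolute origins visited by repeated relative updates."""
    origins = []
    current = list(start)
    for _ in range(num_iters):
        origins.append(tuple(current))
        # Relative update: new origin = current origin + step (in elements per dim).
        current = [c + s for c, s in zip(current, step)]
    return origins
```

For example, `tile_origins((0, 0), (0, 16), 4)` gives `[(0, 0), (0, 16), (0, 32), (0, 48)]`.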
Example 1:
```mlir
%2 = xegpu.update_nd_offset %1, [0, 16] : !xegpu.tensor_desc<8x16xf32>
```
Example 2 (SIMT mode):
```mlir
%2 = xegpu.update_nd_offset %1, [0, 16] :
    !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
```
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
offsets | variadic of index |
Results: ¶
Result | Description |
---|---|
result | TensorDesc describing regions of interest. |
xegpu.update_offset
(xegpu::UpdateOffsetOp) ¶
It updates the offsets for the given tensor descriptor
Syntax:
operation ::= `xegpu.update_offset` $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
It behaves similarly to `update_nd_offset` in that it updates the offset of a TensorDesc, and the offsets are relative to the current position, measured in number of elements. However, `update_nd_offset` updates the start point of a 2D block, so its offsets contain two elements representing the shift in each dimension. `update_offset` instead updates the offset per work-item, so its offsets contain one value per work-item representing that work-item's shift.
Example 1:
```mlir
%off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
%2 = xegpu.update_offset %1, %off :
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size=2>>, vector<4xindex>
```
Example 2 (SIMT mode):
```mlir
%off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
%2 = xegpu.update_offset %1, %off :
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size=2>,
#xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xindex>
```
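The difference from `update_nd_offset` is that every work-item carries its own offset. A minimal model, using `update_offsets` as a hypothetical helper, is an element-wise add of one delta per work-item:

```python
def update_offsets(offsets, deltas):
    """Hypothetical model of update_offset: lane i moves by deltas[i] elements."""
    if len(offsets) != len(deltas):
        raise ValueError("need exactly one delta per work-item")
    return [offset + delta for offset, delta in zip(offsets, deltas)]
```

Adding the `dense<[32, 32, 32, 32]>` deltas from the examples to hypothetical starting offsets `[0, 8, 16, 24]` gives `[32, 40, 48, 56]`; an `update_nd_offset`, by contrast, takes a single shift per dimension for the whole block.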
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
offsets | vector of index values of ranks 1 |
Results: ¶
Result | Description |
---|---|
result | TensorDesc describing regions of interest. |
Attributes ¶
BlockTensorDescAttr ¶
A composite attribute for TensorDescType
Syntax:
#xegpu.block_tdesc_attr<
MemorySpaceAttr, # memory_space
IntegerAttr, # array_length
BoolAttr # boundary_check
>
`BlockTensorDesc` (or `block_tdesc_attr`) is a composite attribute defined for `TensorDescType` for describing the following properties of a `TensorDesc`:
1. `memory_space`: It describes where the data block described by the TensorDesc is located, `Global` device memory or `Shared` local memory. It defaults to `Global`.
2. `array_length`: It describes how many horizontally consecutive blocks will be loaded by a hardware load instruction. If the TensorDesc shape is 8x16 with array_length = 2, the loaded block shape will actually be 8x32. Its default value is 1.
3. `boundary_check`: It indicates to the hardware whether to perform an out-of-boundary check. The default value is true.
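The `array_length` arithmetic can be sketched as follows. This is an illustrative model that follows the description above (only the column count grows); `array_blocks` is hypothetical and does not model how the loaded data is laid out in registers or in the result vector type.

```python
def array_blocks(tdesc_shape, array_length=1):
    """Hypothetical model: footprint and column ranges of one block load.

    array_length horizontally consecutive blocks of the TensorDesc shape are
    loaded side by side, so only the column count is multiplied.
    """
    rows, cols = tdesc_shape
    if array_length < 1:
        raise ValueError("array_length must be >= 1")
    # Half-open column range [start, end) covered by each consecutive block.
    ranges = [(k * cols, (k + 1) * cols) for k in range(array_length)]
    return (rows, cols * array_length), ranges
```

For an `8x16` TensorDesc with `array_length = 2`, it reports an `8x32` footprint made of column ranges `[0, 16)` and `[16, 32)`.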
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
memory_space | MemorySpaceAttr | Default: `Global` |
array_length | IntegerAttr | Default: 1 |
boundary_check | BoolAttr | Default: true |
CachePolicyAttr ¶
Describe the cache settings for prefetch/load/store operators
Syntax:
#xegpu.cache_hint<
::mlir::xegpu::CachePolicy # value
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::CachePolicy | an enum of type CachePolicy |
FenceScopeAttr ¶
Describes the scope of a fence. “workgroup” means that the scope is within each workgroup. “gpu” means that the scope spans workgroups across the GPU.
Syntax:
#xegpu.fence_scope<
::mlir::xegpu::FenceScope # value
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::FenceScope | an enum of type FenceScope |
LayoutAttr ¶
Describes the data distribution to subgroups and work-items for a tensor specified by the tensor descriptor.
Syntax:
#xegpu.layout<
DenseI32ArrayAttr, # sg_layout
DenseI32ArrayAttr, # sg_data
DenseI32ArrayAttr, # inst_data
DenseI32ArrayAttr, # lane_layout
DenseI32ArrayAttr, # lane_data
DenseI32ArrayAttr # order
>
XeGPU operations use `LayoutAttr` to define how data is distributed across subgroups and work-items. This attribute is specified in tensor descriptors when the tensor descriptor is created. `LayoutAttr` includes the following parameters:
- `sg_layout`: Specifies the total number of subgroups and their layout within a workgroup. It is mandatory for workgroup-level programming, and its presence implies workgroup-level code.
- `sg_data`: Defines the data size accessed per subgroup. It is optionally used with `sg_layout` for workgroup-level programming. When it is left empty, the size accessed per subgroup can be derived from the tensor shape and `sg_layout` using the formula `sg_data[i] = tensor_shape[i] / sg_layout[i]`.
- `inst_data`: Specifies the data size that is processed by an instruction. It is optionally used with `lane_layout`. When it is left empty, the data size per instruction is equal to `sg_data` for workgroup-level programming, or to the tensor shape for subgroup-level programming.
- `lane_layout`: Specifies the total number of work-items and their arrangement within a subgroup. It is mandatory for subgroup-level programming and optional for workgroup-level programming.
- `lane_data`: Specifies the shape of the tensor fragment that each lane accesses. It defines a single, minimal distribution unit. Processing the entire tensor may require one or more distribution units per hardware instruction.
- `order`: Specifies the dimension order used to linearize the n-dimensional `sg_layout` and `lane_layout` into a 1-dimensional layout. The first dimension in the order list is the fastest-changing dimension. If it is not present, the default value is [1, 0].
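The `sg_data` formula and the `order` linearization rule can be checked with a short Python sketch. The helpers are hypothetical, assume 2D layouts with the default order `[1, 0]`, and only restate the rules above in code.

```python
def derive_sg_data(tensor_shape, sg_layout):
    """Hypothetical helper: sg_data[i] = tensor_shape[i] / sg_layout[i]."""
    sg_data = []
    for dim, count in zip(tensor_shape, sg_layout):
        if dim % count != 0:
            raise ValueError(f"dimension {dim} is not divisible by {count}")
        sg_data.append(dim // count)
    return sg_data


def linear_id(coord, layout, order=(1, 0)):
    """Linearize a 2D coordinate in `layout`; order[0] is the fastest-changing dim."""
    index, stride = 0, 1
    for dim in order:
        index += coord[dim] * stride
        stride *= layout[dim]
    return index


def id_grid(layout, order=(1, 0)):
    """Grid of linear ids, row by row, for a 2D sg_layout or lane_layout."""
    rows, cols = layout
    return [[linear_id((r, c), layout, order) for c in range(cols)] for r in range(rows)]
```

Here `id_grid([2, 8])` yields `[[0, ..., 7], [8, ..., 15]]`, `id_grid([2, 8], (0, 1))` yields `[[0, 2, ..., 14], [1, 3, ..., 15]]`, and `derive_sg_data([32, 64], [2, 4])` gives `[16, 16]`.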
Examples:
- Subgroup level layout:
#xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>
In this example, there are 16 work-items per subgroup, organized as [[0, 1, 2, ..., 7], [8, 9, ..., 15]]. The distribution unit is 1x1.
- Subgroup level layout with order:
#xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1], order = [0, 1]>
In this example, there are 16 work-items per subgroup, organized as [[0, 2, 4, ..., 14], [1, 3, 5, ..., 15]]. The distribution unit is 1x1.
- Subgroup level layout with inst_data:
#xegpu.layout<inst_data = [8, 16], lane_layout = [2, 8], lane_data = [2, 2]>
In this example, the original problem size is partitioned into smaller subproblems of dimensions [8, 16], which are then distributed among 16 work-items arranged as [[0, 1, 2, ..., 7], [8, 9, ..., 15]]. Within each [8, 16] subproblem, each work-item is assigned two 2x2 blocks in a round-robin manner (128 elements over 16 work-items, 4 elements per block).
- Workgroup level layout:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [2, 8], lane_data = [1, 1]>
In this example, the layout represents a workgroup distribution. A workgroup consists of 8 subgroups arranged as [[0, 1, 2, 3], [4, 5, 6, 7]]. Each subgroup accesses a 16x16 block per instruction, which is further distributed to 16 work-items organized as [[0, 1, 2, ..., 7], [8, 9, ..., 15]].
- Workgroup level layout with order:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [2, 8], lane_data = [1, 1], order = [0, 1]>
In this example, the layout represents a workgroup distribution. A workgroup consists of 8 subgroups arranged as [[0, 2, 4, 6], [1, 3, 5, 7]]. Each subgroup accesses a 16x16 block per instruction, which is further distributed to 16 work-items organized as [[0, 2, 4, ..., 14], [1, 3, 5, ..., 15]].
- Workgroup level layout with inst_data:
#xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], inst_data = [8, 16], lane_layout = [2, 8], lane_data = [1, 1]>
This example is similar to the previous ones, but the `inst_data` parameter divides `sg_data` into two instructions, each processing an 8x16 block. These blocks are further distributed across 16 work-items with a distribution unit of 1x1. Unlike the 2x2 distribution unit in the third example (subgroup level layout with inst_data), which results in accessing contiguous 2x2 blocks, the 1x1 distribution unit may result in non-contiguous access.
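The round-robin assignment in these examples can be modeled by tiling the block with distribution units of `lane_layout[i] * lane_data[i]` elements per dimension. The sketch below is a simplified model under that assumption (2D only, evenly divisible shapes); `distribute` is hypothetical and not part of XeGPU.

```python
from itertools import product


def distribute(block_shape, lane_layout, lane_data, order=(1, 0)):
    """Hypothetical model: map each element of a 2D block to its owning lane.

    The block is tiled by distribution units spanning
    lane_layout[i] * lane_data[i] elements per dimension. Inside a unit, the
    lane at grid position (r, c) owns the lane_data[0] x lane_data[1]
    sub-block starting at (r * lane_data[0], c * lane_data[1]); units repeat
    round-robin across the block. Returns {lane_id: [(row, col), ...]}.
    """
    unit = [lane_layout[d] * lane_data[d] for d in range(2)]
    owners = {}
    for row, col in product(range(block_shape[0]), range(block_shape[1])):
        # Lane-grid coordinate of this element within its distribution unit.
        coord = ((row % unit[0]) // lane_data[0], (col % unit[1]) // lane_data[1])
        # Linearize the lane-grid coordinate; order[0] changes fastest.
        lane, stride = 0, 1
        for dim in order:
            lane += coord[dim] * stride
            stride *= lane_layout[dim]
        owners.setdefault(lane, []).append((row, col))
    return owners


packed = distribute((8, 16), (2, 8), (2, 2))
spread = distribute((8, 16), (2, 8), (1, 1))
```

For an `8x16` block with `lane_layout = [2, 8]`, lane 0 owns two 2x2 blocks (rows 0-1 and 4-5, columns 0-1) when `lane_data = [2, 2]`, but isolated elements in rows 0, 2, 4, 6 and columns 0 and 8 when `lane_data = [1, 1]`, which illustrates the non-contiguous access described above.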
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
sg_layout | DenseI32ArrayAttr | |
sg_data | DenseI32ArrayAttr | |
inst_data | DenseI32ArrayAttr | |
lane_layout | DenseI32ArrayAttr | |
lane_data | DenseI32ArrayAttr | |
order | DenseI32ArrayAttr | |
MemorySpaceAttr ¶
Describes the location of data described by a `TensorDesc`: Global device memory (`Global`) or Shared local memory (`SLM`).
Syntax:
#xegpu.memory_space<
::mlir::xegpu::MemorySpace # value
>
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::MemorySpace | an enum of type MemorySpace |
ScatterTensorDescAttr ¶
A composite attribute for TensorDescType
Syntax:
#xegpu.scatter_tdesc_attr<
MemorySpaceAttr, # memory_space
IntegerAttr # chunk_size
>
`ScatterTensorDesc` is a composite attribute defined for `TensorDescType` for describing the following properties of a `TensorDesc`:
- `memory_space`: It describes where the data block described by the TensorDesc is located, `Global` device memory or `Shared` local memory. It defaults to `Global`.
- `chunk_size`: It indicates the number of contiguous elements accessed for each offset. The default is 1. It is used with the `scattered` attribute only.
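To illustrate `chunk_size`, the sketch below lists the element indices a scattered TensorDesc touches: each offset starts a run of `chunk_size` contiguous elements, which is why a 4x2 scattered TensorDesc with `chunk_size = 2` has one row per offset. `scattered_elements` is a hypothetical helper under this simplified model.

```python
def scattered_elements(offsets, chunk_size=1):
    """Hypothetical model: element indices touched by a scattered TensorDesc.

    Each offset starts a run of chunk_size contiguous elements, so the result
    has one row per offset and chunk_size columns.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    return [[offset + k for k in range(chunk_size)] for offset in offsets]
```

With offsets `[0, 32, 64, 96]` and `chunk_size = 2`, the result is `[[0, 1], [32, 33], [64, 65], [96, 97]]`, a 4x2 shape.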
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
memory_space | MemorySpaceAttr | Data memory location |
chunk_size | IntegerAttr | Number of contiguous elements |
Types ¶
NbarrierType ¶
`!xegpu.nbarrier` is a custom XeGPU type representing a barrier.
Syntax: !xegpu.nbarrier
TensorDescType ¶
TensorDesc describing regions of interest.
TensorDesc is a type designed to describe regions of interest in data, as well as some features unique to Intel hardware. Unlike the built-in tensor type in MLIR, it essentially contains only metadata and does not hold the data itself. It is primarily designed to support 2D block load/store and DPAS (matrix multiplication instruction) on Intel GPUs. It encodes the following information:
- shape: the sizes/shape of the data block of interest, e.g., 8x16 means 8 rows, each containing 16 contiguous data elements. Whether the rows are contiguous with each other depends on the encoding attribute: if the encoding is a `BlockTensorDescAttr`, rows are contiguous; if it is a `ScatterTensorDescAttr`, rows are not necessarily contiguous. If the encoding is not set, it is treated as a default `BlockTensorDescAttr`.
- element_type: the data type of the data elements, e.g., f16, f32.
Similar to the built-in tensor type, it also provides optional attributes for encoding additional information via either `BlockTensorDescAttr` or `ScatterTensorDescAttr`, and for supporting workgroup-, subgroup-, and work-item (SIMT)-level programming via the layout attribute. Please check their definitions for details.
Syntax:
TensorDesc-type ::= `tensor_desc` `<` dim-list element-type (attr-list)? `>`
element-type ::= float-type | integer-type | index-type
dim-list := (static-dim-list `x`)?
static-dim-list ::= decimal-literal `x` decimal-literal
attr-list = (, encoding-attr)? (, layout-attr)?
encoding-attr = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
layout-attr = (, layout `<`sg_layout = value, sg_data = value, inst_data = value, lane_layout = value, lane_data = value, order = value`>`)?
Examples:
```mlir
// A block TensorDesc with 8x16 i32 elements
!xegpu.tensor_desc<8x16xi32>
// A block TensorDesc with 8x16 f32 elements
!xegpu.tensor_desc<8x16xf32>
// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space
!xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
// A 1D TensorDesc with a layout for subgroup level programming; each lane accesses two contiguous elements
!xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [2]>>
// A 1D TensorDesc with a layout for subgroup level programming; each lane accesses two elements with stride = 16
!xegpu.tensor_desc<32xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
// A TensorDesc with a layout for subgroup level programming
!xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// A TensorDesc with a layout for workgroup level programming
!xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
// A TensorDesc with a layout for workgroup level programming without lane_layout and lane_data
!xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16]>>
```
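The two 1D layouts above differ only in `lane_data`. Assuming a simple round-robin distribution model, this hypothetical helper lists the indices each lane accesses:

```python
def lane_elements_1d(length, lanes, data):
    """Hypothetical model: indices each lane accesses in a 1D tensor (round-robin)."""
    unit = lanes * data  # elements covered by one distribution unit
    if length % unit != 0:
        raise ValueError(f"length {length} is not a multiple of {unit}")
    return {
        lane: [base + lane * data + k
               for base in range(0, length, unit)
               for k in range(data)]
        for lane in range(lanes)
    }
```

With 32 elements and 16 lanes, `lane_data = [2]` gives lane 0 the contiguous pair `[0, 1]`, while `lane_data = [1]` gives it `[0, 16]`, a stride of 16.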
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
shape | ::llvm::ArrayRef<int64_t> | |
elementType | mlir::Type | |
encoding | mlir::Attribute | |
layout | mlir::Attribute | |
Enums ¶
CmpFPredicate ¶
Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
Cases: ¶
Symbol | Value | String |
---|---|---|
AlwaysFalse | 0 | false |
OEQ | 1 | oeq |
OGT | 2 | ogt |
OGE | 3 | oge |
OLT | 4 | olt |
OLE | 5 | ole |
ONE | 6 | one |
ORD | 7 | ord |
UEQ | 8 | ueq |
UGT | 9 | ugt |
UGE | 10 | uge |
ULT | 11 | ult |
ULE | 12 | ule |
UNE | 13 | une |
UNO | 14 | uno |
AlwaysTrue | 15 | true |
CmpIPredicate ¶
Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
Cases: ¶
Symbol | Value | String |
---|---|---|
eq | 0 | eq |
ne | 1 | ne |
slt | 2 | slt |
sle | 3 | sle |
sgt | 4 | sgt |
sge | 5 | sge |
ult | 6 | ult |
ule | 7 | ule |
ugt | 8 | ugt |
uge | 9 | uge |
IntegerOverflowFlags ¶
Integer overflow arith flags
Cases: ¶
Symbol | Value | String |
---|---|---|
none | 0 | none |
nsw | 1 | nsw |
nuw | 2 | nuw |
RoundingMode ¶
Floating point rounding mode
Cases: ¶
Symbol | Value | String |
---|---|---|
to_nearest_even | 0 | to_nearest_even |
downward | 1 | downward |
upward | 2 | upward |
toward_zero | 3 | toward_zero |
to_nearest_away | 4 | to_nearest_away |
AtomicRMWKind ¶
Allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
Cases: ¶
Symbol | Value | String |
---|---|---|
addf | 0 | addf |
addi | 1 | addi |
assign | 2 | assign |
maximumf | 3 | maximumf |
maxs | 4 | maxs |
maxu | 5 | maxu |
minimumf | 6 | minimumf |
mins | 7 | mins |
minu | 8 | minu |
mulf | 9 | mulf |
muli | 10 | muli |
ori | 11 | ori |
andi | 12 | andi |
maxnumf | 13 | maxnumf |
minnumf | 14 | minnumf |
FastMathFlags ¶
Floating point fast math flags
Cases: ¶
Symbol | Value | String |
---|---|---|
none | 0 | none |
reassoc | 1 | reassoc |
nnan | 2 | nnan |
ninf | 4 | ninf |
nsz | 8 | nsz |
arcp | 16 | arcp |
contract | 32 | contract |
afn | 64 | afn |
fast | 127 | fast |
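Each basic flag in the table occupies one bit, and `fast` (127) is the bitwise OR of all seven (1 + 2 + 4 + ... + 64). A Python `IntFlag` mirror of the table, which is only a model and not an MLIR binding, makes this easy to check; `combine` is a hypothetical helper.

```python
from enum import IntFlag


class FastMathFlags(IntFlag):
    # Values copied from the FastMathFlags table; each basic flag is one bit.
    none = 0
    reassoc = 1
    nnan = 2
    ninf = 4
    nsz = 8
    arcp = 16
    contract = 32
    afn = 64
    fast = 127  # all seven basic bits set


def combine(*flags):
    """Hypothetical helper: bitwise-OR several flags into one value."""
    result = FastMathFlags.none
    for flag in flags:
        result |= flag
    return result
```

For instance, `combine(FastMathFlags.nnan, FastMathFlags.ninf)` has the integer value 6, and combining all seven basic flags yields `FastMathFlags.fast`.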
CachePolicy ¶
Cache policy
Cases: ¶
Symbol | Value | String |
---|---|---|
CACHED | 0 | cached |
UNCACHED | 1 | uncached |
STREAMING | 2 | streaming |
READ_INVALIDATE | 3 | read_invalidate |
WRITE_BACK | 4 | write_back |
WRITE_THROUGH | 5 | write_through |
FenceScope ¶
The enumeration for the scope of fence operation.
Cases: ¶
Symbol | Value | String |
---|---|---|
Workgroup | 0 | workgroup |
GPU | 1 | gpu |
MemorySpace ¶
The address space of the memory the tensor descriptor is created for
Cases: ¶
Symbol | Value | String |
---|---|---|
Global | 0 | global |
SLM | 3 | slm |