'xegpu' Dialect
The XeGPU dialect that models Intel GPU’s ISA
The XeGPU dialect models Intel Xe ISA semantics but operates on vector and TensorDesc data types. It provides 1:1 mappings to Xe instructions such as DPAS and 2D block load. The matrix sizes processed at this level exactly match those supported by the hardware instructions or by the intrinsics of the lower-level GPU compiler.
Operations
xegpu.alloc_nbarrier (xegpu::AllocNbarrierOp)
It allocates a set of named barriers.
Syntax:
operation ::= `xegpu.alloc_nbarrier` $nbarrier_num attr-dict
AllocNbarrier creates a set of named barriers, as specified by `nbarrier_num`. Named barriers are workgroup-level resources and are shared by all threads in the workgroup. For example, there are up to 32 barriers (IDs 0-31) for each XeCore on PVC. A typical use case is that a workgroup is partitioned into N subgroups of threads (N <= 32), and each subgroup coordinates its work through a separate barrier, with IDs ranging from 0 to N-1.
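The partitioning described above can be sketched in Python. This is a hypothetical host-side model, not part of the dialect. `assign_barriers` is an illustrative name. The model assumes threads are numbered consecutively and split into equal, contiguous subgroups, and that subgroup i uses barrier ID i.

```python
def assign_barriers(num_threads: int, num_subgroups: int,
                    max_barriers: int = 32) -> dict:
    """Map each thread ID to the named-barrier ID used by its subgroup.

    Hypothetical helper: assumes equal, contiguous subgroups and that
    subgroup i coordinates through barrier i (so IDs span 0..N-1).
    """
    if not 1 <= num_subgroups <= max_barriers:
        raise ValueError(f"need 1 <= N <= {max_barriers}, got {num_subgroups}")
    if num_threads % num_subgroups != 0:
        raise ValueError("threads must split evenly into subgroups")
    per_group = num_threads // num_subgroups
    return {tid: tid // per_group for tid in range(num_threads)}


# 8 threads in 4 subgroups: barrier IDs 0, 0, 1, 1, 2, 2, 3, 3
print(assign_barriers(8, 4))
```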
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
nbarrier_num | ::mlir::IntegerAttr | 64-bit signless integer attribute |
xegpu.atomic_rmw (xegpu::AtomicRMWOp)
Atomic read-modify-write operation on the TensorDesc.
Syntax:
operation ::= `xegpu.atomic_rmw` $kind $tensorDesc `,` $mask `,` $value attr-dict `:`
qualified(type($tensorDesc)) `,` type($mask) `,` type($value) `->` type($result)
The `xegpu.atomic_rmw` operation provides a way to perform a read-modify-write operation on the region described by the TensorDesc, free from data races. The `kind` enumeration specifies the modification to be performed. The `mask` operand has the same shape as the TensorDesc and is used to enable or disable specific data points of the TensorDesc. The `value` operand represents the new value to be applied during the modification.
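Here is a minimal sequential Python model that makes the roles of `kind`, `mask`, and `value` concrete. It illustrates the semantics under stated assumptions and is not the lowering. The TensorDesc is reduced to a list of element offsets into a flat `memory` list. Only three `kind` names are modeled, borrowed from `arith::AtomicRMWKind`. The op's result is not modeled at all.

```python
import operator
from typing import Callable, Dict, List

# A few `kind` cases only; the real enumeration has more.
KINDS: Dict[str, Callable] = {
    "addf": operator.add,
    "addi": operator.add,
    "assign": lambda old, new: new,
}


def atomic_rmw(memory: List, offsets: List[int], mask: List[bool],
               values: List, kind: str) -> None:
    """Apply `kind` at every enabled data point, updating `memory` in place.

    Processing the points one after another stands in for the hardware's
    race-free atomicity: if two points share an offset, both updates land.
    """
    op = KINDS[kind]
    for off, enabled, new in zip(offsets, mask, values):
        if enabled:  # disabled points are neither read nor written
            memory[off] = op(memory[off], new)


mem = [0.0] * 8
atomic_rmw(mem, [0, 2, 2, 5], [True, True, True, False], [1.0, 2.0, 3.0, 4.0], "addf")
print(mem)  # [1.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```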
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
kind | ::mlir::arith::AtomicRMWKindAttr | allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 |
Operands:
Operand | Description |
---|---|
tensorDesc | TensorDesc describing regions of interest. |
mask | vector of 1-bit signless integer values of rank 1 or 1-bit signless integer |
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
Results:
Result | Description |
---|---|
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
xegpu.create_nd_tdesc (xegpu::CreateNdDescOp)
Create nd-tensor descriptor operation
Syntax:
operation ::= `xegpu.create_nd_tdesc` $source ``
custom<DynamicIndexList>($offsets, $const_offsets)
(`,` custom<DynamicIndexList>($shape, $const_shape)^
`,` custom<DynamicIndexList>($strides, $const_strides))?
attr-dict `:` type($source) `->` qualified(type($TensorDesc))
The “create_nd_tdesc” operation creates a TensorDescType which represents a sub-view of a 1D/2D memory region inside the one or two innermost dimensions of the source. (It can be extended to support n-D memory regions in the future if needed.) Elements in the subview are contiguous in each dimension. It encodes the following important information for supporting Intel hardware features:
- source: an object representing (the starting address/pointer of) a memory region. It can be either a memref object or simply a pointer represented by the uint64_t type. For dynamic memrefs or pointers, the shape and layout information of the memory region should be explicitly passed via the `shape` and `strides` parameters.
- offsets: index values representing offsets from the “source” in each dimension, at which the subview of the target memory will be created. They are encoded via “offsets” and “const_offsets”, so that various forms are accepted, such as operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
- shape: the shape information of the memory region pointed to by the “source”. It is typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. But if “source” is simply a pointer represented as uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the shape information has to be explicitly passed via the “shape” and “const_shape” arguments.
- strides: the strides of the memory region pointed to by the “source”. Similar to shape, it is typically encoded via the MemRefType of the source too. But if “source” is simply a pointer represented as uint64_t type, or a memref type without shape information, e.g., memref<?x?xf16>, the strides information has to be explicitly passed via the “strides” and “const_strides” arguments.
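The address arithmetic implied by `offsets` and `strides` can be sketched as follows. This is a simplified Python model with two assumptions: strides are counted in elements, and a statically shaped memref such as memref<1024x1024xf32> has the default row-major (identity) layout. The helper names are illustrative, not XeGPU APIs.

```python
from typing import List, Sequence


def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Element strides of an identity-layout memref, innermost dimension last."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return strides[::-1]


def element_offset(offsets: Sequence[int], strides: Sequence[int],
                   index: Sequence[int]) -> int:
    """Linear element offset from `source` of element `index` in the subview."""
    if not len(offsets) == len(strides) == len(index):
        raise ValueError("offsets, strides and index must have the same rank")
    return sum((o + i) * s for o, i, s in zip(offsets, index, strides))


strides = row_major_strides([1024, 1024])        # [1024, 1]
print(element_offset([0, 0], strides, [7, 15]))  # last element of an 8x16 tile: 7183
```

For a dynamic memref<?x?xf32> or a ui64 pointer, the same arithmetic would apply with the explicitly passed strides, e.g. [w, 1] for a row-major h x w region.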
Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = memref.alloc() : memref<1024x1024xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0]: memref<1024x1024xf32> -> TensorDesc<8x16xf32>
```
Example 2 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = memref.alloc(%h, %w) : memref<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: memref<?x?xf32> -> TensorDesc<8x16xf32>
```
Example 3 (suppose the tensor shape inferred by the compiler is 8x16):
```mlir
%0 = ... : ui64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
```
Traits: AlwaysSpeculatableImplTrait, AttrSizedOperandSegments
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OffsetSizeAndStrideOpInterface, ViewLikeOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
const_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
const_strides | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands:
Operand | Description |
---|---|
source | non-0-ranked memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | variadic of index |
shape | variadic of index |
strides | variadic of index |
Results:
Result | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
xegpu.create_tdesc (xegpu::CreateDescOp)
Create scattered tensor descriptors (TensorDesc).
Syntax:
operation ::= `xegpu.create_tdesc` $source `,` $offsets attr-dict `:` type($source) `,` type($offsets) `->` qualified(type($TensorDesc))
“create_tdesc” is similar to “create_nd_tdesc” in that it creates a Tensor Descriptor (TensorDescType) for a memory region. While “create_nd_tdesc” is for creating contiguous subviews, “create_tdesc” is for creating non-contiguous (scattered) subviews, allowing each work-item in a subgroup to specify its own offset. It accepts the following parameters:
- source: a 1D memref or pointer (uint64_t) that represents the flattened memory object.
- offsets: a vector containing the offset of each access point. Its size is fixed to the hardware-supported subgroup size, e.g., 16 on PVC, implying that each element in the vector corresponds to a work-item (SIMT lane) in the subgroup.
The first dimension of the result TensorDesc corresponds to work-items, so it should match the dimension of offsets. It may also have a second dimension corresponding to the chunk_size if the chunk size is larger than 1.
Example 1. It assumes a subgroup size of 4 and accesses a[0], a[16], a[32], a[64].
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %a, %0: memref<1024xf32>, vector<4xindex> -> TensorDesc<4xf32>
```
Example 2. It assumes a subgroup size of 4, with each work-item accessing 8 elements. It accesses 32 data elements in total: a[0:7], a[16:23], a[32:39], a[64:71].
```mlir
%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
      -> TensorDesc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
```
Example 3. It is similar to Example 2, but there is some overlap among work-items. It accesses a[0:7], a[4:11], a[8:15], a[12:19].
```mlir
%0 = memref.alloc() : memref<1024xf32>
%off = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
%1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
      -> TensorDesc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
```
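A small Python model can reproduce the element sets listed in the three examples. It is a sketch under the following assumptions:

- Offsets and chunks are counted in elements.
- The hypothetical `scattered_accesses` returns one row per work-item.
- `tdesc_shape` mirrors the rule that a second dimension appears only when the chunk size is larger than 1.

```python
from typing import List, Tuple


def scattered_accesses(offsets: List[int], chunk_size: int = 1) -> List[List[int]]:
    """Element indices read or written by each work-item (one row per lane)."""
    return [list(range(off, off + chunk_size)) for off in offsets]


def tdesc_shape(offsets: List[int], chunk_size: int = 1) -> Tuple[int, ...]:
    """Shape of the resulting TensorDesc: lanes, plus chunk_size if it exceeds 1."""
    return (len(offsets),) if chunk_size == 1 else (len(offsets), chunk_size)


print(scattered_accesses([0, 16, 32, 64]))      # Example 1: [[0], [16], [32], [64]]
print(tdesc_shape([0, 16, 32, 64], 8))          # Example 2: (4, 8)
print(scattered_accesses([0, 4, 8, 12], 8)[1])  # Example 3, lane 1: elements 4 through 11
```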
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), ViewLikeOpInterface
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
source | non-0-ranked memref of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values or 64-bit unsigned integer or 32-bit unsigned integer or 64-bit signless integer or 32-bit signless integer |
offsets | vector of index values of rank 1 |
Results:
Result | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
xegpu.dpas (xegpu::DpasOp)
It performs MMA (matrix multiply-accumulate) computation.
Syntax:
operation ::= `xegpu.dpas` $lhs `,` $rhs (`,` $acc^)? attr-dict `:` type($lhs)`,` type($rhs) (`,` type($acc)^)? `->` type($result)
DPAS performs matrix multiplication on matrix A of size `mxk` and matrix B of size `kxn`, and accumulates the product with matrix C of size `mxn` into a result matrix of the same size, where `m=8`, `n=16`, and `k=8 * 32/bit_width_of_elem_type`. So for the fp16 data type, the matrices are `A: vector<8x16xf16>`, `B: vector<16x16xf16>`, and `C/D: vector<8x16xf32>`. Besides the matrix size requirements, DPAS also requires A and B to be loaded with the required data layout. Specifically, the VNNI layout is required for the B operand. It is achieved by adding the `packed` attribute to the `load_nd` operation. Due to the VNNI transformation, the B operand can be represented as a 3D vector, with the last dimension representing the VNNI factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>` can be represented as `B: vector<8x16x2xf16>`.
Note: on PVC, the hardware can perform a load with VNNI transformation when the data element type is 16-bit or lower precision, taking 2 or 4 elements from the first dimension and inserting them into the newly added innermost dimension.
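A short Python sketch can check the shape rules and the VNNI packing above. It models shapes and data placement only, with no arithmetic precision. `dpas_shapes` and `vnni_pack` are illustrative names. The packing follows the description above: consecutive groups of `32/bit_width` rows of B are folded into a new innermost dimension. For 32-bit elements the factor is 1 and k is 8.

```python
from typing import Dict, List, Tuple


def dpas_shapes(elem_bits: int) -> Dict[str, Tuple[int, ...]]:
    """Operand shapes for m=8, n=16, k=8*32/elem_bits, plus the VNNI view of B."""
    m, n = 8, 16
    vnni = 32 // elem_bits  # VNNI factor
    k = 8 * vnni
    return {"A": (m, k), "B": (k, n), "C": (m, n), "B_vnni": (k // vnni, n, vnni)}


def vnni_pack(b: List[List], factor: int) -> List[List[List]]:
    """Fold `factor` consecutive rows of a k x n matrix into an innermost dim."""
    k = len(b)
    if k % factor != 0:
        raise ValueError("k must be a multiple of the VNNI factor")
    n = len(b[0])
    return [[[b[r * factor + v][c] for v in range(factor)] for c in range(n)]
            for r in range(k // factor)]


print(dpas_shapes(16))
# {'A': (8, 16), 'B': (16, 16), 'C': (8, 16), 'B_vnni': (8, 16, 2)}
print(vnni_pack([[1, 2], [3, 4], [5, 6], [7, 8]], 2))
# [[[1, 3], [2, 4]], [[5, 7], [6, 8]]]
```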
Traits: AlwaysSpeculatableImplTrait
Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operands:
Operand | Description |
---|---|
lhs | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2/3 |
rhs | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2/3 |
acc | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2 |
Results:
Result | Description |
---|---|
result | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 2 |
xegpu.fence (xegpu::FenceOp)
It synchronizes memory accesses.
Syntax:
operation ::= `xegpu.fence` `memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` `` $fence_scope attr-dict
It synchronizes memory accesses between a write and the following read or write.
1. `memory_kind` describes the memory kind. “global” means the global memory, and “slm” means the shared local memory.
2. `fence_scope` describes the scope of the fence. “workgroup” means that the scope is within each workgroup. “gpu” means that the scope is across workgroups within the GPU.
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
memory_kind | ::mlir::xegpu::MemorySpaceAttr | Describes the location of data described by a `TensorDesc`: Global device memory (`Global`) or Shared local memory (`SLM`). |
fence_scope | ::mlir::xegpu::FenceScopeAttr | Describes the scope of fence. "workgroup" means that the scope is within each work group. "gpu" means the scope is across work groups within the gpu. |
xegpu.init_nbarrier (xegpu::InitNbarrierOp)
It assigns a named barrier to the current thread.
Syntax:
operation ::= `xegpu.init_nbarrier` $nbarrier_id `,` $participant_thread_num attr-dict `:`
type($nbarrier_id) `,` type($participant_thread_num) `->` qualified(type($result))
InitNbarrierOp assigns the named barrier with the specified barrier ID (0-31) to the current thread. Multiple threads may bind to the same named barrier, and `participant_thread_num` specifies the total number of threads associated with the nbarrier. It returns an object of NbarrierType representing the barrier.
Operands:
Operand | Description |
---|---|
nbarrier_id | 8-bit signless integer |
participant_thread_num | 8-bit signless integer |
Results:
Result | Description |
---|---|
result | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.load (xegpu::LoadGatherOp)
Load a set of scattered data points from memory.
Syntax:
operation ::= `xegpu.load` $TensorDesc `,` $mask prop-dict attr-dict
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)
It (aka. load) loads data per work-item. The output describes the data being loaded at the subgroup level, so its size is consistent with the number of work-items in a subgroup. When the chunk size is larger than 1, the output vector is a 2D vector, with dim-1 corresponding to work-items and dim-0 corresponding to the chunk size loaded by each work-item. In particular, there is a transpose effect on the result (as compared to the TensorDesc) due to the hardware implementation. Therefore, a transpose attribute is introduced on purpose, making sure users are aware of this implicit transformation.
The mask operand masks out memory accesses, so it is safe to pass out-of-bounds addresses/offsets as long as they are masked. The mask applies to the SIMD lanes.
Example 1:
```mlir
%2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                        l2_hint = #xegpu.cache_hint<uncached>,
                        l3_hint = #xegpu.cache_hint<uncached>}
      : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
        vector<16xi1> -> vector<16xf32>
```
Example 2:
```mlir
%2 = xegpu.load %1, %0 {transpose,
                        l1_hint = #xegpu.cache_hint<cached>,
                        l2_hint = #xegpu.cache_hint<uncached>,
                        l3_hint = #xegpu.cache_hint<uncached>}
      : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
        vector<16xi1> -> vector<8x16xf32>
```
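Python can model the masking and the transpose effect. This sketch makes three assumptions:

- Memory is a flat list.
- Offsets are counted in elements.
- Masked-off lanes yield a placeholder `fill` value, because the text above does not specify what they return.

`load_gather` is a hypothetical name.

```python
from typing import Any, List, Optional


def load_gather(memory: List, offsets: List[int], mask: List[bool],
                chunk_size: int = 1, fill: Optional[Any] = None) -> List:
    """Subgroup-level result of a scattered load (sequential model).

    chunk_size == 1: one element per lane.
    chunk_size > 1: a chunk_size x lanes nested list, i.e. the transposed
    layout (dim-0 = chunk, dim-1 = lanes) described above.
    """
    per_lane = []
    for off, enabled in zip(offsets, mask):
        if enabled:
            per_lane.append([memory[off + c] for c in range(chunk_size)])
        else:
            # Masked lane: memory is never touched, so `off` may be out of bounds.
            per_lane.append([fill] * chunk_size)
    if chunk_size == 1:
        return [row[0] for row in per_lane]
    return [list(col) for col in zip(*per_lane)]  # lanes x chunk -> chunk x lanes


mem = list(range(100))
print(load_gather(mem, [0, 16, 32, 1000], [True, True, True, False]))  # [0, 16, 32, None]
res = load_gather(mem, [0, 16, 32, 64], [True] * 4, chunk_size=8)
print(len(res), len(res[0]))  # 8 4
```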
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
transpose | ::mlir::UnitAttr | unit attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
mask | vector of 1-bit signless integer values of rank 1 or 1-bit signless integer |
Results:
Result | Description |
---|---|
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
xegpu.load_nd (xegpu::LoadNdOp)
Loads an n-D block from memory (represented by TensorDesc) to registers (represented by vector)
Syntax:
operation ::= `xegpu.load_nd` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
LoadNdOp essentially mimics the hardware block read instruction, reading a block of data from memory into registers. It takes a set of optional cache hints for each level of cache: L1, L2, and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute is masked. VNNI transformation is a hardware feature of Intel GPUs that packs data during the load of the B operand of a matrix operation when the bit width of the data type is less than 32 bits, e.g., fp16. Transpose is another Intel hardware feature, which transposes the data while loading it when the data type is fp32 or fp64. This implies that VNNI and transpose cannot exist at the same time.
Example:
```mlir
xegpu.load_nd %1 {transpose = [1, 0],
                  l1_hint = #xegpu.cache_hint<cached>,
                  l2_hint = #xegpu.cache_hint<uncached>,
                  l3_hint = #xegpu.cache_hint<streaming>}
      : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
```
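A minimal Python model of the block read, with and without `transpose = [1, 0]`, follows. It assumes a 2-D row-major array and in-bounds offsets. It ignores cache hints, boundary handling, and the `packed` (VNNI) mode. `load_nd` here is a hypothetical function, not the op's implementation.

```python
from typing import List, Sequence


def load_nd(memory: List[List], offsets: Sequence[int], block_shape: Sequence[int],
            transpose: bool = False) -> List[List]:
    """Read a rows x cols block at `offsets`; optionally return it transposed."""
    r0, c0 = offsets
    rows, cols = block_shape
    block = [[memory[r0 + r][c0 + c] for c in range(cols)] for r in range(rows)]
    if transpose:  # models transpose = [1, 0]: an 8x16 desc yields a 16x8 value
        block = [list(col) for col in zip(*block)]
    return block


mem = [[100 * r + c for c in range(32)] for r in range(16)]
plain = load_nd(mem, (0, 0), (8, 16))
flipped = load_nd(mem, (0, 0), (8, 16), transpose=True)
print(len(plain), len(plain[0]), len(flipped), len(flipped[0]))  # 8 16 16 8
```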
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
packed | ::mlir::UnitAttr | unit attribute |
transpose | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
Results:
Result | Description |
---|---|
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
xegpu.nbarrier_arrive (xegpu::NbarrierArriveOp)
It signals the arrival at the named barrier.
Syntax:
operation ::= `xegpu.nbarrier_arrive` $nbarrier attr-dict `:` qualified(type($nbarrier))
NbarrierArriveOp signals the hardware (or other threads) that the current thread has produced its data for the consumer threads. When the hardware has been signalled by `participant_thread_num` threads for the named barrier, it notifies the threads waiting on the named barrier to continue their work.
Operands:
Operand | Description |
---|---|
nbarrier | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.nbarrier_wait (xegpu::NbarrierWaitOp)
It waits for a named barrier.
Syntax:
operation ::= `xegpu.nbarrier_wait` $nbarrier attr-dict `:` qualified(type($nbarrier))
NbarrierWaitOp signals the hardware which named barrier the current thread is waiting for, such that it can get notified when the named barrier is completed.
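Python threads can imitate the arrive/wait protocol of `init_nbarrier`, `nbarrier_arrive`, and `nbarrier_wait`. This is a toy, single-use model built on `threading.Condition`. It does not show how the hardware implements named barriers. The `NamedBarrier` class and its methods are hypothetical.

```python
import threading


class NamedBarrier:
    """Toy single-use named barrier (hypothetical; not an XeGPU API)."""

    def __init__(self, barrier_id: int, participant_thread_num: int):
        if not 0 <= barrier_id <= 31:
            raise ValueError("barrier ID must be in 0-31")
        self.barrier_id = barrier_id
        self.participants = participant_thread_num
        self._arrived = 0
        self._cond = threading.Condition()

    def arrive(self) -> None:
        """Like nbarrier_arrive: signal that this thread's data is ready."""
        with self._cond:
            self._arrived += 1
            if self._arrived >= self.participants:
                self._cond.notify_all()

    def wait(self) -> None:
        """Like nbarrier_wait: block until all participants have arrived."""
        with self._cond:
            self._cond.wait_for(lambda: self._arrived >= self.participants)


barrier = NamedBarrier(barrier_id=1, participant_thread_num=2)
produced = []


def producer(x: int) -> None:
    produced.append(x)  # produce data for consumers
    barrier.arrive()


workers = [threading.Thread(target=producer, args=(i,)) for i in range(2)]
for w in workers:
    w.start()
barrier.wait()           # consumer: returns only after both producers arrived
print(sorted(produced))  # [0, 1]
for w in workers:
    w.join()
```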
Operands:
Operand | Description |
---|---|
nbarrier | !xegpu.nbarrier a custom XeGPU type representing a barrier. |
xegpu.prefetch (xegpu::PrefetchOp)
Prefetches a set of scattered data points to cache
Syntax:
operation ::= `xegpu.prefetch` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))
It issues instructions to prefetch a set of scattered data points from memory to each level of the cache based on their cache policy. Unlike prefetch_nd, which works on a non-scattered TensorDesc, it works on a scattered TensorDesc.
Example:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                       l2_hint = #xegpu.cache_hint<cached>,
                       l3_hint = #xegpu.cache_hint<cached>}
      : !xegpu.tensor_desc<16xf16>
```
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
xegpu.prefetch_nd (xegpu::PrefetchNdOp)
Prefetches an n-D block to cache
Syntax:
operation ::= `xegpu.prefetch_nd` $TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))
It issues an instruction to prefetch a block of data from contiguous memory regions to each level of the cache based on their cache policy.
Example:
```mlir
xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                          l2_hint = #xegpu.cache_hint<cached>,
                          l3_hint = #xegpu.cache_hint<cached>}
      : !xegpu.tensor_desc<8x16xf16>
```
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of interest. |
xegpu.store (xegpu::StoreScatterOp)
Store data to scattered memory locations.
Syntax:
operation ::= `xegpu.store` $value `,` $TensorDesc `,` $mask prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)
It (aka. store) stores data to scattered memory locations. The value is typically a 1D vector, but when the chunk size of the TensorDesc is larger than 1, it is a 2D vector instead. In the latter case, dim-1 of the value corresponds to the SIMD lanes and dim-0 corresponds to the chunk size stored per lane. So store_scatter has a transpose effect, similar to load_gather. Therefore, a transpose attribute is introduced on purpose, making sure users are aware of this implicit transformation.
Example 1:
```mlir
xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                        l2_hint = #xegpu.cache_hint<write_back>,
                        l3_hint = #xegpu.cache_hint<write_through>}
      : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
```
Example 2:
```mlir
xegpu.store %0, %1, %2 {transpose,
                        l1_hint = #xegpu.cache_hint<uncached>,
                        l2_hint = #xegpu.cache_hint<write_back>,
                        l3_hint = #xegpu.cache_hint<write_through>}
      : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>>, vector<16xi1>
```
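The Python sketch below shows the inverse mapping, from a transposed value back to per-lane memory locations. As with the load model, it assumes a flat list indexed in elements. `store_scatter` is a hypothetical name. Masked-off lanes leave memory untouched.

```python
from typing import List


def store_scatter(memory: List, value: List, offsets: List[int],
                  mask: List[bool], chunk_size: int = 1) -> None:
    """Write `value` to scattered locations in place.

    chunk_size == 1: `value` holds one element per lane.
    chunk_size > 1: `value` is chunk_size x lanes (the transposed layout),
    so lane `l` writes value[c][l] to memory[offsets[l] + c].
    """
    for lane, (off, enabled) in enumerate(zip(offsets, mask)):
        if not enabled:
            continue
        if chunk_size == 1:
            memory[off] = value[lane]
        else:
            for c in range(chunk_size):
                memory[off + c] = value[c][lane]


mem = [0] * 32
val = [[10 * c + lane for lane in range(2)] for c in range(3)]  # 3 x 2 (chunk x lanes)
store_scatter(mem, val, [0, 20], [True, True], chunk_size=3)
print(mem[0:3], mem[20:23])  # [0, 10, 20] [1, 11, 21]
```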
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
transpose | ::mlir::UnitAttr | unit attribute |
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
TensorDesc | TensorDesc describing regions of interest. |
mask | vector of 1-bit signless integer values of rank 1 or 1-bit signless integer |
xegpu.store_nd (xegpu::StoreNdOp)
Stores an n-D block register region back to memory; currently only 2D is supported.
Syntax:
operation ::= `xegpu.store_nd` $value `,` $TensorDesc prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc))
StoreNdOp essentially mimics the hardware block write instruction to write a block of data from registers into the memory region described by the TensorDesc. It takes a set of optional cache hints for each level of cache: L1, L2, and L3. If the hardware does not have a corresponding cache, the corresponding cache hint attribute is masked.
Example:
```mlir
xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                       l2_hint = #xegpu.cache_hint<write_back>,
                       l3_hint = #xegpu.cache_hint<write_through>}
      : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
```
Attributes:
Attribute | MLIR Type | Description |
---|---|---|
l1_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l2_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
l3_hint | ::mlir::xegpu::CachePolicyAttr | Describes the cache settings for prefetch/load/store operators |
Operands:
Operand | Description |
---|---|
value | vector of 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type values of ranks 1/2/3/4 or 1-bit signless integer or 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 1-bit signed integer or 8-bit signed integer or 16-bit signed integer or 32-bit signed integer or 64-bit signed integer or 1-bit unsigned integer or 8-bit unsigned integer or 16-bit unsigned integer or 32-bit unsigned integer or 64-bit unsigned integer or 16-bit float or 32-bit float or 64-bit float or bfloat16 type or tf32 type |
TensorDesc | TensorDesc describing regions of interest. |
xegpu.update_nd_offset (xegpu::UpdateNdOffsetOp)
It updates the offsets for the TensorDesc.
Syntax:
operation ::= `xegpu.update_nd_offset` $TensorDesc `,`
custom<DynamicIndexList>($offsets, $const_offsets)
attr-dict `:` qualified(type($result))
The op updates the offsets of the given TensorDesc. The offsets are relative to the current position and are measured in number of elements. The result is a TensorDesc of the same type as the input.
example:
%2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
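The offset arithmetic can be sketched as a small Python model. This is a hypothetical illustration, not an MLIR API: BlockDesc, update_nd_offset, and covered_region are made-up names, and %1 is assumed to start at offset (0, 0), which the example does not state.

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class BlockDesc:
    """Hypothetical model of a block TensorDesc: metadata only, no data."""
    shape: Tuple[int, int]    # block shape in elements, e.g. (8, 16)
    offsets: Tuple[int, int]  # current start position (row, col) in elements

def update_nd_offset(desc: BlockDesc, delta: Tuple[int, int]) -> BlockDesc:
    # Offsets are relative to the current position; the shape (and hence
    # the TensorDesc type) is unchanged.
    row, col = desc.offsets
    return replace(desc, offsets=(row + delta[0], col + delta[1]))

def covered_region(desc: BlockDesc) -> Tuple[range, range]:
    # Rows and columns of the underlying matrix covered by the block.
    (r0, c0), (rows, cols) = desc.offsets, desc.shape
    return range(r0, r0 + rows), range(c0, c0 + cols)

# Mirror of the example: advance an 8x16 block by 16 columns.
d1 = BlockDesc(shape=(8, 16), offsets=(0, 0))
d2 = update_nd_offset(d1, (0, 16))
print(d2.offsets)          # (0, 16)
print(covered_region(d2))  # (range(0, 8), range(16, 32))
```

Because the model is immutable, d1 keeps its old offsets, matching the SSA style of the op, which produces a new TensorDesc value instead of modifying %1.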
Attributes: ¶
Attribute | MLIR Type | Description |
---|---|---|
const_offsets | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
offsets | variadic of index |
Results: ¶
Result | Description |
---|---|
result | TensorDesc describing regions of data of interest. |
xegpu.update_offset
(xegpu::UpdateOffsetOp) ¶
It updates the offsets for the given tensor descriptor
Syntax:
operation ::= `xegpu.update_offset` $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
It behaves similarly to update_nd_offset in that it updates the offsets of a TensorDesc, and the offsets are relative to the current position, measured in number of elements. However, update_nd_offset updates the start point of a 2D block, so its offsets contain two elements representing the shift in each dimension. update_offset updates the offset per work-item, so its offsets contain one value per work-item, representing that work-item's shift.
Example:
```mlir
%off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
%2 = xegpu.update_offset %1, %off :
!xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xindex>
```
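A hypothetical Python model of the per-work-item update follows. The starting offsets [0, 16, 32, 48] for %1 are an assumption, since the example does not show how %1 was created; the deltas match %off above. update_offset here is an illustrative function, not an MLIR API.

```python
from typing import List

def update_offset(offsets: List[int], deltas: List[int]) -> List[int]:
    """Hypothetical model: add each work item's delta to its own offset."""
    if len(offsets) != len(deltas):
        raise ValueError("one delta is needed per work-item offset")
    return [off + d for off, d in zip(offsets, deltas)]

# Assumed starting offsets for %1, one per work item, in elements.
off1 = [0, 16, 32, 48]
off2 = update_offset(off1, [32, 32, 32, 32])  # same deltas as %off
print(off2)  # [32, 48, 64, 80]
```

Unlike update_nd_offset, which takes one delta per dimension, this takes one delta per work item, so different work items may move by different amounts.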
Operands: ¶
Operand | Description |
---|---|
TensorDesc | TensorDesc describing regions of data of interest. |
offsets | vector of index values of ranks 1 |
Results: ¶
Result | Description |
---|---|
result | TensorDesc describing regions of data of interest. |
Attributes ¶
BlockTensorDescAttr ¶
a composite attribute for TensorDescType
Syntax:
#xegpu.block_tdesc_attr<
MemorySpaceAttr, # memory_space
IntegerAttr, # array_length
BoolAttr # boundary_check
>
BlockTensorDesc (or block_tdesc_attr) is a composite attribute defined for TensorDescType for describing the following properties of a TensorDesc:
1. memory_space: describes where the data block described by the TensorDesc is located, Global device memory or Shared local memory. It defaults to Global.
2. array_length: describes how many horizontally consecutive blocks are loaded by a single hardware load instruction. If the TensorDesc shape is 8x16 with array_length = 2, the loaded block shape will actually be 8x32. Its default value is 1.
3. boundary_check: indicates to the hardware whether to perform out-of-boundary checks. The default value is true.
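The effect of array_length and boundary_check on a block load can be sketched as a Python reference model. It assumes a row-major 2D matrix and treats the array_length blocks as one contiguous 8x32 region, as the text above does. It also assumes, as stated for TensorDescType, that out-of-boundary elements are padded with zero when checking is on. loaded_shape and load_block are hypothetical names.

```python
from typing import List, Sequence, Tuple

def loaded_shape(shape: Tuple[int, int], array_length: int = 1) -> Tuple[int, int]:
    # array_length blocks placed side by side multiply the column count.
    rows, cols = shape
    return rows, cols * array_length

def load_block(mem: Sequence[Sequence[int]], offsets: Tuple[int, int],
               shape: Tuple[int, int], array_length: int = 1,
               boundary_check: bool = True) -> List[List[int]]:
    """Hypothetical reference model of a 2D block load from a row-major matrix."""
    rows, cols = loaded_shape(shape, array_length)
    r0, c0 = offsets
    height, width = len(mem), len(mem[0])
    out = []
    for r in range(r0, r0 + rows):
        row = []
        for c in range(c0, c0 + cols):
            if 0 <= r < height and 0 <= c < width:
                row.append(mem[r][c])
            elif boundary_check:
                row.append(0)  # out-of-bounds element is padded with zero
            else:
                # Unchecked out-of-bounds behavior is not described, so refuse.
                raise IndexError(f"unchecked out-of-bounds access at ({r}, {c})")
        out.append(row)
    return out

# An 8x24 matrix whose element (r, c) holds r * 100 + c.
mem = [[r * 100 + c for c in range(24)] for r in range(8)]
blk = load_block(mem, (0, 0), (8, 16), array_length=2)
print(len(blk), len(blk[0]))   # 8 32
print(blk[7][23], blk[7][24])  # 723 0
```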
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
memory_space | MemorySpaceAttr | |
array_length | IntegerAttr | 1 |
boundary_check | BoolAttr | true |
CachePolicyAttr ¶
Describe the cache settings for prefetch/load/store operators
Syntax:
#xegpu.cache_hint<
::mlir::xegpu::CachePolicy # value
>
Enum cases:
- cached (CACHED)
- uncached (UNCACHED)
- streaming (STREAMING)
- read_invalidate (READ_INVALIDATE)
- write_back (WRITE_BACK)
- write_through (WRITE_THROUGH)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::CachePolicy | an enum of type CachePolicy |
FenceScopeAttr ¶
Describes the scope of a fence. “workgroup” means that the scope is within each workgroup. “gpu” means the scope is across workgroups within the GPU.
Syntax:
#xegpu.fence_scope<
::mlir::xegpu::FenceScope # value
>
Enum cases:
- workgroup (Workgroup)
- gpu (GPU)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::FenceScope | an enum of type FenceScope |
MemorySpaceAttr ¶
Describe the location of data described by a TensorDesc: Global device memory (Global) or Shared local memory (SLM).
Syntax:
#xegpu.memory_space<
::mlir::xegpu::MemorySpace # value
>
Enum cases:
- global (Global)
- slm (SLM)
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
value | ::mlir::xegpu::MemorySpace | an enum of type MemorySpace |
SGMapAttr ¶
Describes the mapping between work items (WI) and the 2D tensor specified by the tensor descriptor.
To distribute an XeGPU operation to work items, the tensor_desc must be given the sg_map attribute when the tensor descriptor is created.
Within the sg_map, wi_layout specifies the layout of work items, i.e. how work items are mapped onto the tensor.
wi_layout[0] * wi_layout[1] must equal the total number of work items in a subgroup.
wi_data specifies the minimum number of data elements assigned to each work item in a single distribution step.
E.g., with #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>, the subgroup has 16 work items arranged as wi_layout = [1, 16], each accessing 1 element as specified by wi_data = [1, 1].
wi_data[0] * wi_data[1] can be greater than 1, meaning that each work item operates on multiple elements; these are eventually lowered to a “SIMT-flavor” vector, such as a SPIR-V or LLVM vector, or packed into a storage data type.
The multiple elements indicated by wi_data must come from a single dimension (either one) and must be contiguous in memory along that dimension.
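The following Python sketch shows one way to read these rules. It assumes a subgroup of 16 work items, as in the example. It also assumes that work items are assigned round-robin over wi_layout * wi_data tiles repeated across the tensor. The exact assignment order used by the XeGPU lowering may differ, and distribute is a hypothetical helper.

```python
from typing import Dict, List, Tuple

def distribute(shape: Tuple[int, int], wi_layout: Tuple[int, int],
               wi_data: Tuple[int, int], subgroup_size: int = 16
               ) -> Dict[int, List[Tuple[int, int]]]:
    """Hypothetical model: map each work-item id to the (row, col) elements it owns."""
    rows, cols = shape
    lay_r, lay_c = wi_layout
    dat_r, dat_c = wi_data
    if lay_r * lay_c != subgroup_size:
        raise ValueError("wi_layout must cover the whole subgroup")
    if dat_r > 1 and dat_c > 1:
        raise ValueError("wi_data elements must come from one dimension")
    tile_r, tile_c = lay_r * dat_r, lay_c * dat_c  # one distribution step
    if rows % tile_r or cols % tile_c:
        raise ValueError("shape must be a multiple of wi_layout * wi_data")
    owned: Dict[int, List[Tuple[int, int]]] = {}
    for wi_r in range(lay_r):
        for wi_c in range(lay_c):
            elems = []
            # Repeat the distribution tile across the tensor (assumed order).
            for tr in range(0, rows, tile_r):
                for tc in range(0, cols, tile_c):
                    r0, c0 = tr + wi_r * dat_r, tc + wi_c * dat_c
                    elems += [(r, c) for r in range(r0, r0 + dat_r)
                              for c in range(c0, c0 + dat_c)]
            owned[wi_r * lay_c + wi_c] = elems
    return owned

# wi_layout = [1, 16], wi_data = [1, 1] on an 8x16 tensor:
# each work item owns one column, i.e. 8 elements.
d = distribute((8, 16), (1, 16), (1, 1))
print(d[3])  # [(0, 3), (1, 3), (2, 3), (3, 3), (4, 3), (5, 3), (6, 3), (7, 3)]
```

With wi_data = [1, 2] on an 8x32 tensor, work item 1 would own the contiguous pair (r, 2), (r, 3) in every row under this model.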
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
wi_layout | ::llvm::ArrayRef<uint32_t> | |
wi_data | ::llvm::ArrayRef<uint32_t> | |
ScatterTensorDescAttr ¶
a composite attribute for TensorDescType
Syntax:
#xegpu.scatter_tdesc_attr<
MemorySpaceAttr, # memory_space
IntegerAttr # chunk_size
>
ScatterTensorDesc (or scatter_tdesc_attr) is a composite attribute defined for TensorDescType for describing the following properties of a TensorDesc:
1. memory_space: describes where the data block described by the TensorDesc is located, Global device memory or Shared local memory. It defaults to Global.
2. chunk_size: indicates the number of contiguous elements accessed for each offset; the default is 1. It applies to scattered TensorDescs only.
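How chunk_size shapes a scattered access can be sketched in Python. In this hypothetical model (scattered_load is an illustrative name, and memory is assumed to be a flat list of elements), each offset reads chunk_size contiguous elements. The result therefore has one row per offset and chunk_size columns, as in a 4x2 TensorDesc with four offsets and chunk_size = 2.

```python
from typing import List, Sequence

def scattered_load(mem: Sequence[float], offsets: Sequence[int],
                   chunk_size: int = 1) -> List[List[float]]:
    """Hypothetical model: each offset reads chunk_size contiguous elements."""
    result = []
    for off in offsets:
        if off < 0 or off + chunk_size > len(mem):
            raise IndexError(f"chunk at offset {off} is out of range")
        result.append(list(mem[off:off + chunk_size]))
    return result

mem = list(range(100))
print(scattered_load(mem, [0, 16, 32, 48], chunk_size=2))
# [[0, 1], [16, 17], [32, 33], [48, 49]]
```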
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
memory_space | MemorySpaceAttr | |
chunk_size | IntegerAttr | 1 |
Types ¶
NbarrierType ¶
!xegpu.nbarrier: a custom XeGPU type representing a barrier.
Syntax: !xegpu.nbarrier
TensorDescType ¶
TensorDesc describing regions of data of interest.
TensorDesc is a type designed to describe regions of data of interest, as well as some features that are unique to Intel hardware. Unlike the builtin tensor type in MLIR, it contains only metadata and does not hold the data itself. It is designed mainly to support 2D block load/store and DPAS (matrix multiplication instruction) on Intel GPU. It encodes the following information:
- shape: the sizes/shape of the data block of interest, e.g., 8x16 means 8 rows, each containing 16 contiguous data elements. The rows may or may not be contiguous, depending on whether the encoding attribute is set.
- element_type: the data type of the data element, e.g., f16, f32.
Similar to the builtin tensor, it also provides an optional attribute to encode the following information via the TensorDescAttr object:
- memory_space (xegpu::MemorySpace): [optional] where the data is located, global memory or shared memory. It defaults to Global.
- array_length (int): [optional] the number of contiguous blocks, each of size shape, that are loaded by a block load at a time. It defaults to 1.
- boundary_check (bool): [optional] indicates whether the operation detects the boundary and pads with zero for out-of-boundary accesses. Boundary checking is on by default.
Syntax:
TensorDesc-type ::= `tensor_desc` `<` dim-list element-type (attr-list)? `>`
element-type ::= float-type | integer-type | index-type
dim-list ::= (static-dim-list `x`)?
static-dim-list ::= decimal-literal `x` decimal-literal
attr-list ::= (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)? (, sg_map `<` wi_layout = value, wi_data = value `>`)?
Examples:
```mlir
// A block TensorDesc with 8x16 i32 elements
!xegpu.tensor_desc<8x16xi32>
// A block TensorDesc with 8x16 f32 elements
!xegpu.tensor_desc<8x16xf32>
// A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
!xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
// A TensorDesc with a sg_map
!xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
```
Parameters: ¶
Parameter | C++ type | Description |
---|---|---|
shape | ::llvm::ArrayRef<int64_t> | |
elementType | mlir::Type | |
encoding | mlir::Attribute | |
sg_map | mlir::Attribute | |
Enums ¶
CmpFPredicate ¶
allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
Cases: ¶
Symbol | Value | String |
---|---|---|
AlwaysFalse | 0 | false |
OEQ | 1 | oeq |
OGT | 2 | ogt |
OGE | 3 | oge |
OLT | 4 | olt |
OLE | 5 | ole |
ONE | 6 | one |
ORD | 7 | ord |
UEQ | 8 | ueq |
UGT | 9 | ugt |
UGE | 10 | uge |
ULT | 11 | ult |
ULE | 12 | ule |
UNE | 13 | une |
UNO | 14 | uno |
AlwaysTrue | 15 | true |
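The "o" and "u" prefixes distinguish ordered from unordered comparisons. An ordered predicate is false if either operand is NaN, and an unordered predicate is true in that case. A Python sketch of these IEEE 754 semantics, keyed by the String column above (cmpf is a hypothetical name):

```python
import math
import operator

_REL = {"eq": operator.eq, "ne": operator.ne, "gt": operator.gt,
        "ge": operator.ge, "lt": operator.lt, "le": operator.le}

def cmpf(predicate: str, a: float, b: float) -> bool:
    """Evaluate a CmpFPredicate string on two floats."""
    unordered = math.isnan(a) or math.isnan(b)  # at least one NaN
    if predicate == "false":
        return False
    if predicate == "true":
        return True
    if predicate == "ord":
        return not unordered
    if predicate == "uno":
        return unordered
    prefix, rel = predicate[0], predicate[1:]  # e.g. "ogt" -> ("o", "gt")
    if unordered:
        # Ordered predicates are false on NaN; unordered ones are true.
        return prefix == "u"
    return _REL[rel](a, b)

print(cmpf("oeq", math.nan, 1.0))  # False
print(cmpf("ueq", math.nan, 1.0))  # True
```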
CmpIPredicate ¶
allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
Cases: ¶
Symbol | Value | String |
---|---|---|
eq | 0 | eq |
ne | 1 | ne |
slt | 2 | slt |
sle | 3 | sle |
sgt | 4 | sgt |
sge | 5 | sge |
ult | 6 | ult |
ule | 7 | ule |
ugt | 8 | ugt |
uge | 9 | uge |
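Signed and unsigned predicates differ only in how the same bit pattern is interpreted. This Python sketch takes raw bit patterns of a given width and applies the predicate strings from the table. cmpi and to_signed are hypothetical names.

```python
import operator

_REL = {"eq": operator.eq, "ne": operator.ne, "lt": operator.lt,
        "le": operator.le, "gt": operator.gt, "ge": operator.ge}

def to_signed(bits: int, width: int) -> int:
    # Two's-complement reinterpretation of a width-bit pattern.
    bits &= (1 << width) - 1
    return bits - (1 << width) if bits >> (width - 1) else bits

def cmpi(predicate: str, a: int, b: int, width: int) -> bool:
    """Compare two width-bit patterns using a CmpIPredicate string."""
    mask = (1 << width) - 1
    a, b = a & mask, b & mask
    if predicate in ("eq", "ne"):
        return _REL[predicate](a, b)  # sign-agnostic
    sign, rel = predicate[0], predicate[1:]  # e.g. "slt" -> ("s", "lt")
    if sign == "s":
        a, b = to_signed(a, width), to_signed(b, width)
    return _REL[rel](a, b)

print(cmpi("slt", 0xFF, 0x01, 8))  # True: 0xFF is -1 as a signed i8
print(cmpi("ult", 0xFF, 0x01, 8))  # False: 0xFF is 255 as an unsigned i8
```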
IntegerOverflowFlags ¶
Integer overflow arith flags
Cases: ¶
Symbol | Value | String |
---|---|---|
none | 0 | none |
nsw | 1 | nsw |
nuw | 2 | nuw |
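nsw and nuw assert that an operation does not wrap in the signed or unsigned sense, respectively. In the arith dialect, the result is poison if that promise is broken. The Python sketch below reports which kind of wrap a width-bit addition of raw bit patterns would perform. add_overflow is a hypothetical helper, and only addition is modeled.

```python
def to_signed(bits: int, width: int) -> int:
    # Two's-complement reinterpretation of a width-bit pattern.
    bits &= (1 << width) - 1
    return bits - (1 << width) if bits >> (width - 1) else bits

def add_overflow(a: int, b: int, width: int) -> dict:
    """Report whether width-bit a + b wraps in the signed and unsigned sense."""
    mask = (1 << width) - 1
    a, b = a & mask, b & mask
    signed_sum = to_signed(a, width) + to_signed(b, width)
    return {
        # nsw is violated when the signed sum leaves [-2^(w-1), 2^(w-1) - 1].
        "signed_wrap": not (-(1 << (width - 1)) <= signed_sum < (1 << (width - 1))),
        # nuw is violated when the unsigned sum exceeds 2^w - 1.
        "unsigned_wrap": a + b > mask,
    }

print(add_overflow(127, 1, 8))  # {'signed_wrap': True, 'unsigned_wrap': False}
print(add_overflow(255, 1, 8))  # {'signed_wrap': False, 'unsigned_wrap': True}
```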
RoundingMode ¶
Floating point rounding mode
Cases: ¶
Symbol | Value | String |
---|---|---|
to_nearest_even | 0 | to_nearest_even |
downward | 1 | downward |
upward | 2 | upward |
toward_zero | 3 | toward_zero |
to_nearest_away | 4 | to_nearest_away |
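The five modes can be illustrated by rounding a real value to an integer with Python's decimal module, whose rounding constants correspond to these cases. This only shows rounding direction. In the arith dialect the enum is used, for example, by arith.truncf when narrowing to a lower-precision float, which this sketch does not model. round_to_integer is a hypothetical name.

```python
from decimal import (Decimal, ROUND_CEILING, ROUND_DOWN, ROUND_FLOOR,
                     ROUND_HALF_EVEN, ROUND_HALF_UP)

# decimal equivalents of each RoundingMode case. ROUND_DOWN truncates
# toward zero; ROUND_HALF_UP breaks ties away from zero.
_MODES = {
    "to_nearest_even": ROUND_HALF_EVEN,
    "downward": ROUND_FLOOR,
    "upward": ROUND_CEILING,
    "toward_zero": ROUND_DOWN,
    "to_nearest_away": ROUND_HALF_UP,
}

def round_to_integer(x: float, mode: str) -> int:
    """Round x to an integer; Decimal(x) represents the float exactly."""
    return int(Decimal(x).to_integral_value(rounding=_MODES[mode]))

for mode in _MODES:
    print(mode, round_to_integer(2.5, mode), round_to_integer(-2.5, mode))
```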
AtomicRMWKind ¶
allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
Cases: ¶
Symbol | Value | String |
---|---|---|
addf | 0 | addf |
addi | 1 | addi |
assign | 2 | assign |
maximumf | 3 | maximumf |
maxs | 4 | maxs |
maxu | 5 | maxu |
minimumf | 6 | minimumf |
mins | 7 | mins |
minu | 8 | minu |
mulf | 9 | mulf |
muli | 10 | muli |
ori | 11 | ori |
andi | 12 | andi |
maxnumf | 13 | maxnumf |
minnumf | 14 | minnumf |
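The kind selects how the value operand is combined with the current memory contents in xegpu.atomic_rmw. Below is a sequential, non-atomic Python reference for a subset of the cases. It assumes that masked-off data points keep their old value and that each kind follows the arith operation of the same name. For example, maximumf propagates NaN while maxnumf returns the non-NaN operand. rmw_reference and COMBINE are hypothetical names.

```python
import math
from typing import Callable, Dict, List, Sequence

def _maximumf(a: float, b: float) -> float:
    # maximumf: NaN if either operand is NaN (signed zeros not modeled).
    return math.nan if math.isnan(a) or math.isnan(b) else max(a, b)

def _maxnumf(a: float, b: float) -> float:
    # maxnumf: if one operand is NaN, return the other operand.
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)

# Subset of AtomicRMWKind: new memory value = f(old memory value, value).
# Integer wrap-around is not modeled for addi.
COMBINE: Dict[str, Callable] = {
    "addf": lambda old, v: old + v,
    "addi": lambda old, v: old + v,
    "assign": lambda old, v: v,
    "mulf": lambda old, v: old * v,
    "maximumf": _maximumf,
    "maxnumf": _maxnumf,
    "ori": lambda old, v: old | v,
    "andi": lambda old, v: old & v,
}

def rmw_reference(kind: str, mem: Sequence, mask: Sequence[bool],
                  value: Sequence) -> List:
    """Sequential reference: masked-off data points keep their old value."""
    if not (len(mem) == len(mask) == len(value)):
        raise ValueError("mem, mask and value must have the same shape")
    f = COMBINE[kind]
    return [f(old, v) if m else old for old, m, v in zip(mem, mask, value)]

print(rmw_reference("addf", [1.0, 2.0, 3.0], [True, False, True], [10.0] * 3))
# [11.0, 2.0, 13.0]
```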
FastMathFlags ¶
Floating point fast math flags
Cases: ¶
Symbol | Value | String |
---|---|---|
none | 0 | none |
reassoc | 1 | reassoc |
nnan | 2 | nnan |
ninf | 4 | ninf |
nsz | 8 | nsz |
arcp | 16 | arcp |
contract | 32 | contract |
afn | 64 | afn |
fast | 127 | fast |
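Each flag value above is a distinct power of two, so flags combine with bitwise OR. fast (127) equals the OR of the seven individual flags: 1 + 2 + 4 + 8 + 16 + 32 + 64 = 127. The Python mirror below is hypothetical, not an MLIR binding, and only demonstrates that encoding.

```python
from enum import IntFlag

class FastMath(IntFlag):
    """Hypothetical Python mirror of the FastMathFlags values listed above."""
    none = 0
    reassoc = 1
    nnan = 2
    ninf = 4
    nsz = 8
    arcp = 16
    contract = 32
    afn = 64
    fast = 127

def flags_from_strings(names):
    # Combine flag names (as in the String column) with bitwise OR.
    result = FastMath.none
    for name in names:
        result |= FastMath[name]
    return result

print(int(flags_from_strings(["nnan", "ninf"])))  # 6
```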
CachePolicy ¶
Cache policy
Cases: ¶
Symbol | Value | String |
---|---|---|
CACHED | 0 | cached |
UNCACHED | 1 | uncached |
STREAMING | 2 | streaming |
READ_INVALIDATE | 3 | read_invalidate |
WRITE_BACK | 4 | write_back |
WRITE_THROUGH | 5 | write_through |
FenceScope ¶
The enumeration for the scope of fence operation.
Cases: ¶
Symbol | Value | String |
---|---|---|
Workgroup | 0 | workgroup |
GPU | 1 | gpu |
MemorySpace ¶
The address space of the memory the tensor descriptor is created for
Cases: ¶
Symbol | Value | String |
---|---|---|
Global | 0 | global |
SLM | 3 | slm |