'amx' Dialect
The Intel Advanced Matrix Extensions (AMX) provide a tile matrix multiply unit (TMUL), a tile control register (TILECFG), and eight tile registers TMM0 through TMM7 (TILEDATA).
This AMX dialect provides a bridge between MLIR concepts such as
vectors and memrefs and the lower level LLVM IR support of AMX.
Note that since configuration changes (implicit at dialect level) are costly, it is highly recommended to use the AMX dialect on same-shaped vectors, at least within a single method.
For details, see the Intel documentation: https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
Operations
amx.tile_load (amx::TileLoadOp)
Tile load operation
Syntax:
operation ::= `amx.tile_load` $base `[` $indices `]` (`,` $stride^ )? attr-dict`:` type($base) `into` qualified(type($res))
Loads a tile from memory defined by a base and indices, with the
shape defined by the 2-D tile type of the result.
The tile’s rows are populated by reading contiguous elements starting
at the base. For each successive tile row, the base advances by
stride elements.
The tile is loaded using the following indexing scheme:
for row in range(tile_rows):
  mem_row = base[i0, i1, ..., iN + row * stride]
  for col in range(tile_cols):
    tile[row, col] = mem_row[col]
If the stride is not provided, then the base buffer must be at least
2-dimensional, and the stride is automatically inferred and corresponds
to the stride of the buffer’s second innermost dimension.
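The indexing scheme and the inferred stride can be modeled in plain Python. This is only a reference sketch, not how the dialect is implemented: the memref is assumed to be a contiguous row-major buffer held in a flat list, and `tile_load`, `inferred_stride`, and `offset` are hypothetical names introduced for illustration.

```python
def inferred_stride(shape):
    """Stride of the second innermost dimension, assuming a contiguous row-major buffer."""
    if len(shape) < 2:
        raise ValueError("an implicit stride needs a buffer with at least 2 dimensions")
    return shape[-1]


def tile_load(buffer, offset, stride, tile_rows, tile_cols):
    """Read tile_rows rows of tile_cols contiguous elements, rows `stride` elements apart."""
    tile = []
    for row in range(tile_rows):
        start = offset + row * stride
        tile.append(buffer[start:start + tile_cols])
    return tile


# A 4x8 buffer stored row-major as a flat list: element (r, c) holds r * 8 + c.
shape = (4, 8)
buffer = list(range(shape[0] * shape[1]))

# Load a 2x3 tile whose top-left corner is element (1, 2).
offset = 1 * shape[1] + 2
tile = tile_load(buffer, offset, inferred_stride(shape), tile_rows=2, tile_cols=3)
print(tile)
```

Passing an explicit stride instead of `inferred_stride(shape)` corresponds to the 1-D memref form of the operation, where no second innermost dimension exists to infer it from.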
The operation is eventually lowered into the “tileloadd” instruction with the corresponding tile configuration.
With the write memory effect, each amx.tile_load operation serves as
a compilation hint to use a separate tile register.
Example:
// Tile load from a 2-D memref with implicit stride.
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
// Tile load from a 1-D memref with explicit stride.
%0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
Traits: AttrSizedOperandSegments
Interfaces: AMXIntrinsicOpInterface, MemoryEffectOpInterface (MemoryEffectOpInterface), OneToOneIntrinsicOpInterface
Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}
Operands:
| Operand | Description |
|---|---|
| base | memref of any type values |
| indices | variadic of index |
| stride | index |
Results:
| Result | Description |
|---|---|
| res | tile of 32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer values |
amx.tile_mulf (amx::TileMulFOp)
Tile multiplication operation (floating-point)
Syntax:
operation ::= `amx.tile_mulf` $lhs `,` $rhs `,` $acc attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc))
Multiplies a “m x k” tile with a “k x n” tile and accumulates the results into a “m x n” destination tile. Supports “f32 <- bf16 x bf16” (with pairs of “bf16”).
The operation is eventually lowered into the “tdpbf16ps” instruction with the corresponding tile configuration.
Example:
%0 = amx.tile_mulf %a, %b, %c
  : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
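The pairing of “bf16” elements can be made concrete with a small Python model. It is a sketch under stated assumptions: the right-hand tile is taken to hold each pair of consecutive k-values side by side in one row (so k/2 rows of 2n elements, which matches the 16x32 shapes in the example), ordinary Python floats stand in for bf16 and f32, and neither rounding nor hardware summation order is modeled. `tile_mulf_model` is a hypothetical helper name.

```python
def tile_mulf_model(lhs, rhs, acc):
    """Accumulate pairwise dot products: lhs is m x k, rhs is (k/2) x (2n), acc is m x n."""
    m, k = len(lhs), len(lhs[0])
    n = len(rhs[0]) // 2
    if len(rhs) * 2 != k:
        raise ValueError("rhs must have k/2 rows of element pairs")
    res = [row[:] for row in acc]  # leave the input accumulator untouched
    for i in range(m):
        for j in range(n):
            for p in range(k // 2):
                res[i][j] += (lhs[i][2 * p] * rhs[p][2 * j]
                              + lhs[i][2 * p + 1] * rhs[p][2 * j + 1])
    return res


lhs = [[1.0, 2.0, 3.0, 4.0]]      # 1x4: two pairs along k
rhs = [[1.0, 10.0, 2.0, 20.0],    # pair 0 for output columns 0 and 1
       [3.0, 30.0, 4.0, 40.0]]    # pair 1 for output columns 0 and 1
acc = [[0.5, 0.0]]
res = tile_mulf_model(lhs, rhs, acc)
print(res)
```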
Traits: AlwaysSpeculatableImplTrait
Interfaces: AMXIntrinsicOpInterface, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OneToOneIntrinsicOpInterface
Effects: MemoryEffects::Effect{}
Operands:
| Operand | Description |
|---|---|
| lhs | tile of 16-bit float or bfloat16 type values |
| rhs | tile of 16-bit float or bfloat16 type values |
| acc | tile of 32-bit float values |
Results:
| Result | Description |
|---|---|
| res | tile of 32-bit float values |
amx.tile_muli (amx::TileMulIOp)
Tile multiplication operation (integer)
Syntax:
operation ::= `amx.tile_muli` $lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc))
Multiplies a “m x k” tile with a “k x n” tile and accumulates the results into a “m x n” destination tile. Supports all “si32 <- s/ui8 x s/ui8” combinations: 4 bytes are packed into dwords in the columns of both source operand tiles, and zero or sign extension is specified with the attributes, defaulting to sign extension.
The operation is eventually lowered into one of the “tdpbssd”, “tdpbsud”, “tdpbusd”, or “tdpbuud” instructions with the corresponding tile configuration.
Example:
%0 = amx.tile_muli %a zext, %b zext, %c
  : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
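The effect of the `zext` attributes on the packed bytes can be illustrated with a Python model. This is a sketch under assumptions: bytes are given as raw values in 0..255, the right-hand tile is assumed to hold k/4 rows of 4n bytes (matching the 16x64 shapes in the example), and 32-bit wraparound of the accumulator is not modeled. `extend` and `tile_muli_model` are hypothetical helper names.

```python
def extend(byte, zext):
    """Interpret a raw 8-bit value as unsigned (zext) or two's-complement signed."""
    byte &= 0xFF
    return byte if zext or byte < 0x80 else byte - 0x100


def tile_muli_model(lhs, rhs, acc, zext_lhs=False, zext_rhs=False):
    """Accumulate 4-byte dot products: lhs is m x k, rhs is (k/4) x (4n), acc is m x n."""
    m, k = len(lhs), len(lhs[0])
    n = len(rhs[0]) // 4
    if len(rhs) * 4 != k:
        raise ValueError("rhs must have k/4 rows of packed dwords")
    res = [row[:] for row in acc]  # leave the input accumulator untouched
    for i in range(m):
        for j in range(n):
            for p in range(k // 4):
                for b in range(4):
                    res[i][j] += (extend(lhs[i][4 * p + b], zext_lhs)
                                  * extend(rhs[p][4 * j + b], zext_rhs))
    return res


lhs = [[0xFF, 1, 2, 3]]  # one dword of four bytes
rhs = [[1, 1, 1, 1]]     # one dword for a single output column
acc = [[10]]
signed = tile_muli_model(lhs, rhs, acc)                   # 0xFF read as -1
unsigned = tile_muli_model(lhs, rhs, acc, zext_lhs=True)  # 0xFF read as 255
print(signed, unsigned)
```

The same bit pattern contributes very different values depending on the attribute, which is why the extension mode has to be chosen per operand.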
Traits: AlwaysSpeculatableImplTrait
Interfaces: AMXIntrinsicOpInterface, ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), OneToOneIntrinsicOpInterface
Effects: MemoryEffects::Effect{}
Attributes:
| Attribute | MLIR Type | Description |
|---|---|---|
| isZextLhs | ::mlir::UnitAttr | unit attribute |
| isZextRhs | ::mlir::UnitAttr | unit attribute |
Operands:
| Operand | Description |
|---|---|
| lhs | tile of 8-bit signless integer values |
| rhs | tile of 8-bit signless integer values |
| acc | tile of 32-bit signless integer values |
Results:
| Result | Description |
|---|---|
| res | tile of 32-bit signless integer values |
amx.tile_store (amx::TileStoreOp)
Tile store operation
Syntax:
operation ::= `amx.tile_store` $base `[` $indices `]` `,` $val (`,` $stride^ )?attr-dict `:` type($base) `,` qualified(type($val))
Stores a tile to memory defined by a base and indices, with the
shape defined by the 2-D tile type of the value.
The tile’s rows are written contiguously to the buffer starting at
the base. For each successive tile row, the base advances by
stride elements.
The tile is stored using the following indexing scheme:
for row in range(tile_rows):
  mem_row = base[i0, i1, ..., iN + row * stride]
  for col in range(tile_cols):
    mem_row[col] = tile[row, col]
If the stride is not provided, then the base buffer must be at least
2-dimensional, and the stride is automatically inferred and corresponds
to the stride of the buffer’s second innermost dimension.
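The store indexing scheme can be sketched in Python in the same spirit. The memref is again assumed to be a contiguous row-major buffer held in a flat list; `tile_store` and `offset` are hypothetical names used only for illustration, not part of the dialect.

```python
def tile_store(buffer, offset, stride, tile):
    """Write each tile row contiguously, rows `stride` elements apart."""
    for row, values in enumerate(tile):
        start = offset + row * stride
        buffer[start:start + len(values)] = values


# A zero-filled 4x8 buffer stored row-major as a flat list.
buffer = [0] * 32
tile = [[1, 2, 3], [4, 5, 6]]

# Store the 2x3 tile with its top-left corner at element (1, 2); the stride
# of 8 is the length of one buffer row.
tile_store(buffer, offset=10, stride=8, tile=tile)
print(buffer[8:16])
print(buffer[16:24])
```

Only the elements covered by the tile are overwritten; the rest of each buffer row keeps its previous contents.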
The operation is eventually lowered into the “tilestored” instruction with the corresponding tile configuration.
Example:
// Tile store to a 2-D memref with implicit stride.
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
// Tile store to a 1-D memref with explicit stride.
amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
Traits: AttrSizedOperandSegments
Interfaces: AMXIntrinsicOpInterface, OneToOneIntrinsicOpInterface
Operands:
| Operand | Description |
|---|---|
| base | memref of any type values |
| indices | variadic of index |
| val | tile of 32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer values |
| stride | index |
amx.tile_zero (amx::TileZeroOp)
Tile zero operation
Syntax:
operation ::= `amx.tile_zero` attr-dict `:` qualified(type($res))
Zeroes the destination tile, with the shape defined by the 2-D tile type of the result.
The operation is eventually lowered into the “tilezero” instruction with the corresponding tile configuration.
With the write memory effect, each amx.tile_zero operation serves as
a compilation hint to use a separate tile register.
Example:
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
Interfaces: AMXIntrinsicOpInterface, MemoryEffectOpInterface (MemoryEffectOpInterface), OneToOneIntrinsicOpInterface
Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}
Results:
| Result | Description |
|---|---|
| res | tile of 32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer values |
Types
TileType
AMX 2D tile to be used by AMX operations.
This type is used to represent values in AMX tile registers. All AMX operations work on AMX tiles, and these tiles cannot be used directly in other operations. The LLVM IR type for an AMX tile is a primitive type, but in MLIR we provide the shape and element type for IR verification and lowering to the LLVMIR dialect.
Parameters:
| Parameter | C++ type | Description |
|---|---|---|
| shape | ::llvm::ArrayRef<int64_t> | |
| elementType | ::mlir::Type | 32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer |
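Because the tile type carries a shape and an element type, a shape can be checked against the size of a tile register before lowering. The sketch below assumes the hardware limits of at most 16 rows and at most 64 bytes per row for a tile register; `fits_in_tile_register` is a hypothetical helper, and the dialect's own verifier may apply different or additional checks.

```python
MAX_TILE_ROWS = 16       # assumed hardware limit on rows per tile register
MAX_TILE_ROW_BYTES = 64  # assumed hardware limit on bytes per tile row

# Bit widths of the element types listed for the tile type.
ELEMENT_BITS = {"f32": 32, "f16": 16, "bf16": 16, "i32": 32, "i8": 8}


def fits_in_tile_register(rows, cols, element_type):
    """Return True if a rows x cols tile of element_type fits in one tile register."""
    row_bytes = cols * ELEMENT_BITS[element_type] // 8
    return 0 < rows <= MAX_TILE_ROWS and 0 < row_bytes <= MAX_TILE_ROW_BYTES


for shape in [(16, 64, "i8"), (16, 32, "bf16"), (16, 16, "f32"), (16, 32, "f32")]:
    print(shape, fits_in_tile_register(*shape))
```

Under these assumptions the shapes used in the examples on this page all fit, while a 16x32 tile of f32 would need 128 bytes per row and does not.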