MLIR

Multi-Level IR Compiler Framework

'amx' Dialect

The Intel Advanced Matrix Extensions (AMX) provide a tile matrix multiply unit (TMUL), a tile control register (TILECFG), and eight tile registers TMM0 through TMM7 (TILEDATA).

This AMX dialect provides a bridge between MLIR concepts, such as vectors and memrefs, and the lower-level LLVM IR support for AMX. The dialect is split into user-facing AMX ops (AMX_Op) and backend-facing intrinsic ops (AMX_IntrOp).

Note that because configuration changes (implicit at the dialect level) are costly, it is highly recommended to use the AMX dialect on same-shaped vectors, at least within a single method.

For details, see the Intel documentation: https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html

Operations 


amx.tdpbf16ps (amx::x86_amx_tdpbf16ps) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type

Results: 

Result   Description
res      LLVM dialect-compatible type

amx.tdpbssd (amx::x86_amx_tdpbssd) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type

Results: 

Result   Description
res      LLVM dialect-compatible type

amx.tdpbsud (amx::x86_amx_tdpbsud) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type

Results: 

Result   Description
res      LLVM dialect-compatible type

amx.tdpbusd (amx::x86_amx_tdpbusd) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type

Results: 

Result   Description
res      LLVM dialect-compatible type

amx.tdpbuud (amx::x86_amx_tdpbuud) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type

Results: 

Result   Description
res      LLVM dialect-compatible type

amx.tdpfp16ps (amx::x86_amx_tdpfp16ps) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type
«unnamed»   LLVM dialect-compatible type

Results: 

Result   Description
res      LLVM dialect-compatible type

amx.tile_load (amx::TileLoadOp) 

Tile load operation

Syntax:

operation ::= `amx.tile_load` $base `[` $indices `]` attr-dict `:` type($base) `into` qualified(type($res))

Loads a tile from memory defined by a base and indices, with the shape defined by the 2-D tile type of the result. This is eventually lowered into the “tileloadd” instruction with the corresponding tile configuration.

Example:

  %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
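To make the addressing concrete, the following Python sketch models what amx.tile_load computes, not how MLIR implements it. It assumes the memref is a row-major 2-D Python list and that the tile shape (rows, cols) has already been read off the result tile type; the function name tile_load is illustrative only. The bounds check is a choice of this model, not something this document says the hardware or the lowering performs.

```python
# Reference model of amx.tile_load semantics (a sketch, not the MLIR API).
# Assumption: the memref is a row-major 2-D Python list, and the tile shape
# (rows x cols) comes from the result tile type. Names are hypothetical.

def tile_load(base, indices, rows, cols):
    """Copy a rows x cols window of `base` starting at indices=(i, j)."""
    i, j = indices
    # Bounds check is a choice of this model, not a documented hardware check.
    if i + rows > len(base) or j + cols > len(base[0]):
        raise IndexError("tile window exceeds memref bounds")
    # Each tile row is one contiguous run of `cols` elements in memory;
    # consecutive tile rows are one memref row stride apart.
    return [list(base[i + r][j:j + cols]) for r in range(rows)]


if __name__ == "__main__":
    # Hypothetical 20x70 i8 buffer loaded as a 16x64 tile at [0, 0],
    # mirroring `amx.tile_load %arg0[%c0, %c0] ... into !amx.tile<16x64xi8>`.
    buf = [[(r * 70 + c) % 128 for c in range(70)] for r in range(20)]
    t = tile_load(buf, (0, 0), 16, 64)
    print(len(t), len(t[0]))  # 16 64
```

The memref may be larger than the tile; only the window selected by the indices and the tile shape is read.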

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand   Description
base      memref of any type values
indices   variadic of index

Results: 

Result   Description
res      tile of 32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer values

amx.tile_mulf (amx::TileMulFOp) 

Tile multiplication operation (floating-point)

Syntax:

operation ::= `amx.tile_mulf` $lhs `,` $rhs `,` $acc attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc))

Multiplies an “m x k” tile with a “k x n” tile and accumulates the results into an “m x n” destination tile. Supports “f32 <- bf16 x bf16” (with pairs of “bf16”) and “f32 <- f16 x f16” (with pairs of “f16”). The operation is eventually lowered into the “tdpbf16ps” or “tdpfp16ps” instruction, respectively, with the corresponding tile configuration.

Example:

  %0 = amx.tile_mulf %a, %b, %c
    : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
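The pairing of bf16 values can be modeled in a few lines of Python. This sketch follows the tdpbf16ps semantics described in Intel's documentation and is not MLIR code. The rhs tile holds the k x n matrix in a pair-interleaved layout: row p contains elements 2p and 2p+1 of each column side by side, which is why the example's rhs is 16x32xbf16 for k = 32 and n = 16. The helper vnni_pack and the name tile_mulf_ref are hypothetical, and the model computes in Python floats, so it ignores bf16 input rounding and f32 accumulation rounding.

```python
# Reference model of the tdpbf16ps pairing used by amx.tile_mulf (a sketch).
# Computes in Python floats; bf16/f32 rounding is deliberately ignored.

def vnni_pack(b_mat):
    """Hypothetical helper: pack a k x n matrix into (k/2) x (2n) pair layout."""
    k, n = len(b_mat), len(b_mat[0])
    assert k % 2 == 0, "k must be even: values are consumed in pairs"
    # Row p holds b[2p][j], b[2p+1][j] adjacent for every column j.
    return [[b_mat[2 * p + t][j] for j in range(n) for t in range(2)]
            for p in range(k // 2)]


def tile_mulf_ref(a, b_packed, acc):
    """acc[i][j] += sum over pairs p of a[i][2p]*B[2p][j] + a[i][2p+1]*B[2p+1][j]."""
    m, kpairs, n = len(a), len(a[0]) // 2, len(acc[0])
    assert len(b_packed) == kpairs and len(b_packed[0]) == 2 * n
    out = [row[:] for row in acc]
    for i in range(m):
        for p in range(kpairs):
            for j in range(n):
                out[i][j] += (a[i][2 * p] * b_packed[p][2 * j]
                              + a[i][2 * p + 1] * b_packed[p][2 * j + 1])
    return out


if __name__ == "__main__":
    A = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]      # m=2, k=4
    B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]  # k=4, n=2
    acc = [[0.0, 0.0], [0.0, 0.0]]
    print(tile_mulf_ref(A, vnni_pack(B), acc))  # [[12.0, 5.0], [28.0, 13.0]]
```

With the packed layout, the result equals an ordinary A x B plus the accumulator.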

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand   Description
lhs       tile of 16-bit float or bfloat16 type values
rhs       tile of 16-bit float or bfloat16 type values
acc       tile of 32-bit float values

Results: 

Result   Description
res      tile of 32-bit float values

amx.tile_muli (amx::TileMulIOp) 

Tile multiplication operation (integer)

Syntax:

operation ::= `amx.tile_muli` $lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc))

Multiplies an “m x k” tile with a “k x n” tile and accumulates the results into an “m x n” destination tile. Supports all “si32 <- s/ui8 x s/ui8” combinations (4 bytes packed into dwords in the columns of both source operand tiles). Zero or sign extension of each source is specified with the isZextLhs and isZextRhs attributes and defaults to sign extension. The operation is eventually lowered into one of the “tdpbssd”, “tdpbsud”, “tdpbusd”, or “tdpbuud” instructions with the corresponding tile configuration.

Example:

  %0 = amx.tile_muli %a zext, %b zext, %c
    : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
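The Python sketch below models the two parts of tile_muli described above: how the zext flags select one of the four instructions (the first letter after “tdpb” describes the lhs, the second the rhs), and what each 4-byte dot product computes. Tiles are lists of raw byte values (0 to 255), interpreted as signed or unsigned per flag. The names select_instruction and tile_muli_ref are hypothetical, and the 32-bit wrap-around on accumulation is an assumption based on Intel's pseudocode (no saturation) that should be checked against the SDM.

```python
# Reference model of amx.tile_muli semantics (a sketch, not MLIR code).
# Tiles are lists of raw byte values 0..255; all names are hypothetical.

def select_instruction(zext_lhs: bool, zext_rhs: bool) -> str:
    """tdpb<lhs><rhs>d, where 's' = sign-extended and 'u' = zero-extended."""
    return "tdpb" + ("u" if zext_lhs else "s") + ("u" if zext_rhs else "s") + "d"


def ext8(byte: int, zext: bool) -> int:
    """Interpret one raw byte as unsigned (zext) or signed (sext)."""
    byte &= 0xFF
    return byte if zext or byte < 128 else byte - 256


def to_i32(x: int) -> int:
    """Wrap to 32-bit two's complement (assumed: no saturation)."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x


def tile_muli_ref(a, b, acc, zext_lhs=False, zext_rhs=False):
    """a: m x 4K bytes, b: K x 4N bytes, acc: m x N; K counts 4-byte groups."""
    m, kq, n = len(a), len(a[0]) // 4, len(acc[0])
    assert len(b) == kq and len(b[0]) == 4 * n, "inconsistent tile shapes"
    out = [row[:] for row in acc]
    for i in range(m):
        for q in range(kq):
            for j in range(n):
                # Dot product of one 4-byte group of lhs with one of rhs.
                dot = sum(ext8(a[i][4 * q + t], zext_lhs) * ext8(b[q][4 * j + t], zext_rhs)
                          for t in range(4))
                out[i][j] = to_i32(out[i][j] + dot)
    return out


if __name__ == "__main__":
    for zl in (False, True):
        for zr in (False, True):
            print(zl, zr, select_instruction(zl, zr))
    a = [[0xFF, 1, 2, 3]]  # 0xFF is -1 when sign-extended, 255 when zero-extended
    b = [[2, 2, 2, 2]]     # K = 1, N = 1
    print(tile_muli_ref(a, b, [[0]]))                 # [[10]]
    print(tile_muli_ref(a, b, [[0]], zext_lhs=True))  # [[522]]
```

In the example above, `%a zext, %b zext` corresponds to select_instruction(True, True), that is, “tdpbuud”.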

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute   MLIR Type          Description
isZextLhs   ::mlir::UnitAttr   unit attribute
isZextRhs   ::mlir::UnitAttr   unit attribute

Operands: 

Operand   Description
lhs       tile of 8-bit signless integer values
rhs       tile of 8-bit signless integer values
acc       tile of 32-bit signless integer values

Results: 

Result   Description
res      tile of 32-bit signless integer values

amx.tile_store (amx::TileStoreOp) 

Tile store operation

Syntax:

operation ::= `amx.tile_store` $base `[` $indices `]` `,` $val attr-dict `:` type($base) `,` qualified(type($val))

Stores a tile to memory defined by a base and indices, with the shape defined by the 2-D tile type of the value. This is eventually lowered into the “tilestored” instruction with the corresponding tile configuration.

Example:

  amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
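The store is the mirror image of the load. This Python sketch models it under the same assumptions: the memref is a row-major 2-D Python list, the tile is a list of rows whose shape comes from the value's tile type, and the name tile_store is hypothetical. Elements of the memref outside the tile window are left untouched.

```python
# Reference model of amx.tile_store semantics (a sketch, not the MLIR API).
# Assumption: the memref is a row-major 2-D Python list; names are hypothetical.

def tile_store(base, indices, tile):
    """Write `tile` (rows x cols) into `base` starting at indices=(i, j)."""
    i, j = indices
    rows, cols = len(tile), len(tile[0])
    # Bounds check is a choice of this model, not a documented hardware check.
    if i + rows > len(base) or j + cols > len(base[0]):
        raise IndexError("tile window exceeds memref bounds")
    for r in range(rows):
        # Each tile row lands in one contiguous run of the memref row.
        base[i + r][j:j + cols] = tile[r]


if __name__ == "__main__":
    buf = [[0] * 4 for _ in range(3)]
    tile_store(buf, (1, 1), [[1, 2], [3, 4]])
    print(buf)  # [[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0]]
```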

Operands: 

Operand   Description
base      memref of any type values
indices   variadic of index
val       tile of 32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer values

amx.tile_zero (amx::TileZeroOp) 

Tile zero operation

Syntax:

operation ::= `amx.tile_zero` attr-dict `:` qualified(type($res))

Zeroes the destination tile, with the shape defined by the 2-D tile type of the result. This is eventually lowered into the “tilezero” instruction with the corresponding tile configuration.

Example:

  %0 = amx.tile_zero : !amx.tile<16x16xbf16>
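Because tile_mulf and tile_muli accumulate into their acc operand, a common use of tile_zero (assumed here, not stated in this document) is to initialize an accumulator before a chain of multiplications over blocks of the k dimension. The Python sketch below shows that pattern with plain row-major matrices; it deliberately ignores pair/quad packing and tile-size limits, and all function names are hypothetical.

```python
# Sketch: a zero-initialized accumulator reused across k-blocks.
# Simplified: plain row-major lists, no VNNI packing, no tile-size limits.
# All names are hypothetical.

def tile_zero(rows, cols):
    """Model of amx.tile_zero: a rows x cols tile of zeros."""
    return [[0] * cols for _ in range(rows)]


def mul_acc(a, b, acc):
    """acc + a @ b for plain lists (stands in for tile_mulf / tile_muli)."""
    out = [row[:] for row in acc]
    for i in range(len(a)):
        for p in range(len(b)):
            for j in range(len(b[0])):
                out[i][j] += a[i][p] * b[p][j]
    return out


def blocked_matmul(A, B, kb):
    """Compute A @ B by accumulating kb-wide slices of k into one zeroed tile."""
    m, k, n = len(A), len(A[0]), len(B[0])
    acc = tile_zero(m, n)  # the accumulator must start at zero
    for k0 in range(0, k, kb):
        a_blk = [row[k0:k0 + kb] for row in A]  # m x kb slice of A
        b_blk = B[k0:k0 + kb]                    # kb x n slice of B
        acc = mul_acc(a_blk, b_blk, acc)
    return acc


if __name__ == "__main__":
    print(blocked_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]], 1))  # [[19, 22], [43, 50]]
```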

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results: 

Result   Description
res      tile of 32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer values

amx.tileloadd64 (amx::x86_amx_tileloadd64) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM pointer type
«unnamed»   integer

Results: 

Result   Description
res      LLVM dialect-compatible type

amx.tilestored64 (amx::x86_amx_tilestored64) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer
«unnamed»   LLVM pointer type
«unnamed»   integer
«unnamed»   LLVM dialect-compatible type

amx.tilezero (amx::x86_amx_tilezero) 

Operands: 

Operand     Description
«unnamed»   integer
«unnamed»   integer

Results: 

Result   Description
res      LLVM dialect-compatible type

Types 

TileType 

AMX 2D tile to be used by AMX operations.

This type is used to represent values in AMX tile registers. All AMX operations work on AMX tiles, and these tiles cannot be used directly in other operations. The LLVM IR type for an AMX tile (x86_amx) is a primitive type that carries no shape, so in MLIR the tile type provides a shape and element type for IR verification and for lowering to the LLVM dialect.
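The shape and element type carried by the tile type are what make size checks possible. The following Python sketch checks a candidate tile type against the AMX palette-1 limits from Intel's documentation (at most 16 rows and 64 bytes per row, i.e. 1 KiB per tile) and against the element types listed in the parameter table. The function name is hypothetical, and whether MLIR's verifier enforces exactly these checks, and at which point, is not stated in this document.

```python
# Sketch: check a candidate !amx.tile<RxCxT> against AMX palette-1 limits.
# Limits are from Intel's documentation; the function name is hypothetical.
MAX_ROWS = 16            # at most 16 rows per tile
MAX_BYTES_PER_ROW = 64   # at most 64 bytes per row (1 KiB per tile)

# Element types allowed by the TileType parameter table, with bit widths.
ELEMENT_BITS = {"f32": 32, "f16": 16, "bf16": 16, "i32": 32, "i8": 8}


def tile_type_errors(shape, elem):
    """Return a list of problems; an empty list means the type looks valid."""
    errors = []
    if len(shape) != 2:
        errors.append("tile must be 2-D")
    if elem not in ELEMENT_BITS:
        errors.append(f"unsupported element type {elem!r}")
    if errors:
        return errors
    rows, cols = shape
    if not 0 < rows <= MAX_ROWS:
        errors.append(f"row count {rows} outside 1..{MAX_ROWS}")
    row_bytes = cols * ELEMENT_BITS[elem] // 8
    if not 0 < row_bytes <= MAX_BYTES_PER_ROW:
        errors.append(f"row width {row_bytes} bytes outside 1..{MAX_BYTES_PER_ROW}")
    return errors


if __name__ == "__main__":
    # Tile types used in the examples of this document.
    for shape, elem in [((16, 64), "i8"), ((16, 32), "bf16"),
                        ((16, 16), "f32"), ((16, 16), "bf16"), ((16, 16), "i32")]:
        print(shape, elem, tile_type_errors(shape, elem) or "ok")
```

Each tile type appearing in the op examples above passes these checks; a type such as 16x32xf32 (128 bytes per row) would not.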

Parameters: 

Parameter     C++ type                    Description
shape         ::llvm::ArrayRef<int64_t>
elementType   ::mlir::Type                32-bit float or 16-bit float or bfloat16 type or 32-bit signless integer or 8-bit signless integer