MLIR

Multi-Level IR Compiler Framework

'arm_sme' Dialect

Basic dialect to target Arm SME architectures. This dialect contains the definitions necessary to target Arm SME scalable matrix operations.

Sources: https://developer.arm.com/documentation/ddi0616 https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions

Operation definition 

arm_sme.cast_tile_to_vector (arm_sme::CastTileToVector) 

Cast from tile id to 2-d scalable vector type

Syntax:

operation ::= `arm_sme.cast_tile_to_vector` $tile_id attr-dict `:` type($tile_id) `to` type($vector)

A cast_tile_to_vector operation does a cast from a tile id to a 2-d scalable vector type, which represents an SME “virtual tile”. This would normally be used when lowering operations that return “virtual tile” vector types to model the output. This is required to preserve dataflow as SME intrinsics have no return values.

Example:

Input:

%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

After lowering vector.load:

%tile_id = arm_sme.get_tile_id : i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

In the example above, the vector.load can’t be replaced with an SME intrinsic that has no outputs since it is used by the vector.store. However, by inserting a cast_tile_to_vector op after the load intrinsics the vector.load can be replaced. This enables “local” rewrites on individual vector ops, rather than “global” rewrites that would have to look at the vector op uses and also lower them.

Canonicalization will look through arm_sme.cast_tile_to_vector and fold the cast away if its operand comes from an arm_sme.cast_vector_to_tile.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
tile_id | 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer

Results: 

Result | Description
vector | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values

arm_sme.cast_vector_to_tile (arm_sme::CastVectorToTile) 

Cast from 2-d scalable vector type to tile id

Syntax:

operation ::= `arm_sme.cast_vector_to_tile` $vector attr-dict `:` type($vector) `to` type($tile_id)

A cast_vector_to_tile operation does a cast from a 2-d scalable vector type, which represents an SME “virtual tile”, to a tile id. This is required to preserve dataflow as the SME intrinsics have no return values.

Example:

Input:

%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

After lowering vector.store:

%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}

Canonicalization will look through arm_sme.cast_vector_to_tile and fold the cast away if its operand comes from an arm_sme.cast_tile_to_vector.
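
The folding rule for the two cast operations can be illustrated with a small Python sketch. The tuple-based op list and the `fold_cast_pairs` helper are hypothetical stand-ins used only for illustration; they are not the MLIR canonicalizer API, and the sketch ignores types (it assumes the two casts round-trip to the same type).

```python
# Toy model: each op is (result, op_name, operand). Not the MLIR API.
TILE_TO_VECTOR = "arm_sme.cast_tile_to_vector"
VECTOR_TO_TILE = "arm_sme.cast_vector_to_tile"
INVERSE = {TILE_TO_VECTOR: VECTOR_TO_TILE, VECTOR_TO_TILE: TILE_TO_VECTOR}


def fold_cast_pairs(ops):
    """Forward x to the users of cast(inverse_cast(x)) and drop the outer cast."""
    producers = {result: (name, operand) for result, name, operand in ops}
    replaced = {}  # folded result -> value that replaces it
    kept = []
    for result, name, operand in ops:
        operand = replaced.get(operand, operand)
        producer = producers.get(operand)
        if name in INVERSE and producer and producer[0] == INVERSE[name]:
            # This cast undoes its producer, so its users can read the
            # producer's input directly.
            replaced[result] = replaced.get(producer[1], producer[1])
            continue
        kept.append((result, name, operand))
    return kept


ops = [
    ("%tile_id", VECTOR_TO_TILE, "%tile"),
    ("%tile_2", TILE_TO_VECTOR, "%tile_id"),
    ("%use", "some.user", "%tile_2"),
]
folded = fold_cast_pairs(ops)
```

After folding, `%use` reads `%tile` directly. The first cast is left in place but has no remaining users, so in this toy model a separate dead-code cleanup would be expected to remove it.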

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
vector | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values

Results: 

Result | Description
tile_id | 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer

arm_sme.get_tile_id (arm_sme::GetTileID) 

Returns an SME “virtual tile” id

Syntax:

operation ::= `arm_sme.get_tile_id` attr-dict `:` type($tile_id)

A get_tile_id operation returns a scalar integer representing an SME “virtual tile” id. The bitwidth of the scalar indicates the element bitwidth of the “virtual tile”.

A tile id is scoped to the function in which it is created; it cannot be passed to or returned from functions.

Example:

// Allocate and return an 8-bit element "virtual tile" id
%za0_b = arm_sme.get_tile_id : i8

Example:

// Allocate and return two 16-bit element "virtual tile" ids
%za0_h = arm_sme.get_tile_id : i16
%za1_h = arm_sme.get_tile_id : i16

Example:

// Allocate and return a 128-bit element "virtual tile" id
%za0_q = arm_sme.get_tile_id : i128
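
The relationship between the bitwidth of a tile id and the "virtual tile" it refers to can be sketched in Python. The helper names `tile_type` and `num_tiles` are hypothetical. The tile shapes match the vector types listed on this page; the tile counts are taken from the Arm SME architecture (ZA0.B, ZA0-1.H, ZA0-3.S, ZA0-7.D, ZA0-15.Q) rather than from this page.

```python
SUPPORTED_WIDTHS = (8, 16, 32, 64, 128)


def tile_type(element_bits, element_kind="i"):
    """Return the 2-D scalable vector type that models a tile of this width."""
    if element_bits not in SUPPORTED_WIDTHS:
        raise ValueError(f"unsupported element width: {element_bits}")
    n = 128 // element_bits  # elements per tile slice when vscale == 1
    return f"vector<[{n}]x[{n}]x{element_kind}{element_bits}>"


def num_tiles(element_bits):
    """Number of ZA tiles the architecture provides for this element width."""
    if element_bits not in SUPPORTED_WIDTHS:
        raise ValueError(f"unsupported element width: {element_bits}")
    return element_bits // 8  # one tile per byte of element size


# An i32 tile id refers to one of four vector<[4]x[4]xi32> tiles.
i32_tile = tile_type(32)
i32_count = num_tiles(32)
```

Under these assumptions, the two `i16` ids in the second example above would account for every available 16-bit element tile.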

Results: 

Result | Description
tile_id | 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer

arm_sme.intr.ld1b.horiz (arm_sme::aarch64_sme_ld1b_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1b.vert (arm_sme::aarch64_sme_ld1b_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1d.horiz (arm_sme::aarch64_sme_ld1d_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1d.vert (arm_sme::aarch64_sme_ld1d_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1h.horiz (arm_sme::aarch64_sme_ld1h_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1h.vert (arm_sme::aarch64_sme_ld1h_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1q.horiz (arm_sme::aarch64_sme_ld1q_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1q.vert (arm_sme::aarch64_sme_ld1q_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1w.horiz (arm_sme::aarch64_sme_ld1w_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.ld1w.vert (arm_sme::aarch64_sme_ld1w_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.mopa (arm_sme::aarch64_sme_mopa) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.mopa.wide (arm_sme::aarch64_sme_mopa_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.mops (arm_sme::aarch64_sme_mops) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.mops.wide (arm_sme::aarch64_sme_mops_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.read.horiz (arm_sme::aarch64_sme_read_horiz) 

Operands: 

Operand | Description
vector | rank-1 scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1
pg | rank-1 scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

Results: 

Result | Description
res | LLVM dialect-compatible type

arm_sme.intr.read.vert (arm_sme::aarch64_sme_read_vert) 

Operands: 

Operand | Description
vector | rank-1 scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1
pg | rank-1 scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

Results: 

Result | Description
res | LLVM dialect-compatible type

arm_sme.intr.smopa.wide (arm_sme::aarch64_sme_smopa_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.smops.wide (arm_sme::aarch64_sme_smops_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.st1b.horiz (arm_sme::aarch64_sme_st1b_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1b.vert (arm_sme::aarch64_sme_st1b_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1d.horiz (arm_sme::aarch64_sme_st1d_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1d.vert (arm_sme::aarch64_sme_st1d_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1h.horiz (arm_sme::aarch64_sme_st1h_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1h.vert (arm_sme::aarch64_sme_st1h_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1q.horiz (arm_sme::aarch64_sme_st1q_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1q.vert (arm_sme::aarch64_sme_st1q_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1w.horiz (arm_sme::aarch64_sme_st1w_horiz) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.st1w.vert (arm_sme::aarch64_sme_st1w_vert) 

Operands: 

Operand | Description
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2/1
«unnamed» | LLVM pointer type
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer

arm_sme.intr.str (arm_sme::aarch64_sme_str) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | LLVM pointer type

arm_sme.intr.sumopa.wide (arm_sme::aarch64_sme_sumopa_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.sumops.wide (arm_sme::aarch64_sme_sumops_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.umopa.wide (arm_sme::aarch64_sme_umopa_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.umops.wide (arm_sme::aarch64_sme_umops_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.usmopa.wide (arm_sme::aarch64_sme_usmopa_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.usmops.wide (arm_sme::aarch64_sme_usmops_wide) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 1-bit signless integer values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2
«unnamed» | scalable vector of 8-bit signless integer or 16-bit signless integer or bfloat16 type or 16-bit float or 32-bit float or 64-bit float values of length 16/8/4/2

arm_sme.intr.write.horiz (arm_sme::aarch64_sme_write_horiz) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer
pg | rank-1 scalable vector of 1-bit signless integer values of length 16/8/4/2/1
vector | rank-1 scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1

arm_sme.intr.write.vert (arm_sme::aarch64_sme_write_vert) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer
«unnamed» | 32-bit signless integer
pg | rank-1 scalable vector of 1-bit signless integer values of length 16/8/4/2/1
vector | rank-1 scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1

arm_sme.intr.za.disable (arm_sme::aarch64_sme_za_disable) 

arm_sme.intr.za.enable (arm_sme::aarch64_sme_za_enable) 

arm_sme.intr.zero (arm_sme::aarch64_sme_zero) 

Operands: 

Operand | Description
«unnamed» | 32-bit signless integer

arm_sme.load_tile_slice (arm_sme::LoadTileSliceOp) 

Tile slice load and update operation

Syntax:

operation ::= `arm_sme.load_tile_slice` $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
              attr-dict `:` type($base) `,` type($result)

Loads a 1D tile slice from memory into a 2D SME “virtual tile”. The length of the tile slice is given by the dimension of the tile's 2D scalable vector type; for example, a vector<[16]x[16]xi8> tile has vector<[16]xi8> slices. The tile slice index specifies which slice of the input tile is overwritten by the load. An optional tile slice layout attribute specifies whether the tile slice being loaded at the given index is horizontal (default) or vertical. The updated tile is returned as the result.

The slice of memory read is defined by a base and indices and must be contiguous. The memref must be either rank 1 or rank 2, have dynamic dimensions since the operation is scalable, and the element type must be a scalar that matches the element type of the result.

Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.

%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>

Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.

%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>

Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.

%tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
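
To make the slice semantics concrete, here is a minimal Python sketch that models a tile as a square list of lists for one fixed runtime size. The function name mirrors the operation, but the flat `memory` list, the `offset` parameter, and the list-based tile are illustrative assumptions, not the dialect's API.

```python
def load_tile_slice(memory, offset, tile, slice_index, layout="horizontal"):
    """Return a copy of `tile` with one slice replaced by contiguous memory."""
    if layout not in ("horizontal", "vertical"):
        raise ValueError(f"unknown layout: {layout}")
    n = len(tile)  # the tile is n x n once vscale is fixed at runtime
    data = memory[offset:offset + n]
    if len(data) != n:
        raise ValueError("slice read runs past the end of memory")
    updated = [row[:] for row in tile]  # the input tile value is not modified
    for i, value in enumerate(data):
        if layout == "horizontal":
            updated[slice_index][i] = value  # fill row `slice_index`
        else:
            updated[i][slice_index] = value  # fill column `slice_index`
    return updated


memory = list(range(16))
tile = [[0] * 4 for _ in range(4)]
row_update = load_tile_slice(memory, 4, tile, 1)
col_update = load_tile_slice(memory, 4, tile, 1, "vertical")
```

The same four contiguous elements land in row 1 for the horizontal layout and in column 1 for the vertical layout; in both cases the original `tile` value is untouched and the updated tile is returned.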

Interfaces: InferTypeOpInterface

Attributes: 

Attribute | MLIR Type | Description
layout | ::mlir::arm_sme::TileSliceLayoutAttr | Layout of a tile slice

Enum cases:

  • horizontal (Horizontal)
  • vertical (Vertical)

Operands: 

Operand | Description
base | memref of any type values
tile | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
indices | index
tile_slice_index | index

Results: 

Result | Description
result | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values

arm_sme.move_tile_slice_to_vector (arm_sme::MoveTileSliceToVectorOp) 

Move slice of a 2-D tile to a 1-D scalable vector

Syntax:

operation ::= `arm_sme.move_tile_slice_to_vector` $tile `[` $tile_slice_index `]` attr-dict
              `:` type($result) `from` type($tile)

The tile slice to vector operation extracts a 1-D scalable slice from a 2-D scalable tile at the given index. A tile slice is a 1-D vector of horizontally or vertically contiguous elements within a ZA tile. Horizontal tile slices are currently assumed when lowering to intrinsics.

Example 1: Extract vector<[16]xi8> from tile at the given index.

%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>

Example 2: Extract vector<[2]xf64> from tile at the given index.

%slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
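
As a rough model of the extraction, the following Python sketch treats a tile as a square list of lists and returns one horizontal slice (a row), since horizontal slices are the ones assumed when lowering. The list-based representation is an assumption made for illustration.

```python
def move_tile_slice_to_vector(tile, slice_index):
    """Extract horizontal slice `slice_index` (one row) as a new 1-D list."""
    if not 0 <= slice_index < len(tile):
        raise IndexError("tile slice index out of range")
    return list(tile[slice_index])


# A 4 x 4 tile whose element at (row, col) is 10 * row + col.
tile = [[10 * r + c for c in range(4)] for r in range(4)]
tile_slice = move_tile_slice_to_vector(tile, 2)
```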

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
tile | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
tile_slice_index | index

Results: 

Result | Description
result | rank-1 scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1

arm_sme.move_vector_to_tile_slice (arm_sme::MoveVectorToTileSliceOp) 

Move 1-D scalable vector to slice of 2-D tile

Syntax:

operation ::= `arm_sme.move_vector_to_tile_slice` $vector `,` $tile `,` $tile_slice_index
              attr-dict `:` type($vector) `into` type($result)

The vector to tile slice operation moves a 1-D scalable vector to a slice of a 2-D scalable vector tile at the given index. The type of the 1-D scalable vector to be moved must match the type of the tile slice. A tile slice is a 1-D vector of horizontally or vertically contiguous elements within a ZA tile. Horizontal tile slices are currently assumed when lowering to intrinsics. The updated tile is returned as the result.

Example 1: Move a vector<[16]xi8> into tile at given index.

%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>

Example 2: Move a vector<[2]xf64> into tile at given index.

%tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
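
The value semantics described above (the input tile is unchanged and an updated tile is returned) can be sketched in Python. The list-based tile and the length check standing in for the type-match requirement are assumptions of this sketch, which models horizontal slices only.

```python
def move_vector_to_tile_slice(vector, tile, slice_index):
    """Return a copy of `tile` whose row `slice_index` is replaced by `vector`."""
    if len(vector) != len(tile[slice_index]):
        raise ValueError("vector length must match the tile slice length")
    updated = [row[:] for row in tile]  # the input tile value is left unchanged
    updated[slice_index] = list(vector)
    return updated


tile = [[0.0] * 2 for _ in range(2)]
tile_update = move_vector_to_tile_slice([1.5, 2.5], tile, 1)
```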

Interfaces: InferTypeOpInterface

Operands: 

Operand | Description
vector | rank-1 scalable vector of 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer or 16-bit float or bfloat16 type or 32-bit float or 64-bit float values of length 16/8/4/2/1
tile | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
tile_slice_index | index

Results: 

Result | Description
result | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values

arm_sme.store_tile_slice (arm_sme::StoreTileSliceOp) 

Tile slice store operation

Syntax:

operation ::= `arm_sme.store_tile_slice` $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
              attr-dict `:` type($base) `,` type($tile)

Stores a 1D tile slice from a 2D SME “virtual tile” into memory. The length of the tile slice is given by the dimension of the tile's 2D scalable vector type; for example, a vector<[16]x[16]xi8> tile has vector<[16]xi8> slices. The tile slice index specifies which slice of the input tile is stored. An optional tile slice layout attribute specifies whether the tile slice being stored from the given index is horizontal (default) or vertical.

The slice of memory written is defined by a base and indices and must be contiguous. The memref must be either rank 1 or rank 2, have dynamic dimensions since the operation is scalable, and the element type must be a scalar that matches the element type of the input tile.

Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.

arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>

Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.

arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>

Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.

arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
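
Unlike the slice load, the store produces no result and only writes memory. A minimal Python sketch of that behaviour follows; the flat `memory` list and the `offset` parameter are assumptions standing in for the memref base and indices.

```python
def store_tile_slice(tile, slice_index, memory, offset, layout="horizontal"):
    """Write one slice of `tile` to contiguous memory starting at `offset`."""
    n = len(tile)
    if layout == "horizontal":
        data = list(tile[slice_index])             # row `slice_index`
    elif layout == "vertical":
        data = [row[slice_index] for row in tile]  # column `slice_index`
    else:
        raise ValueError(f"unknown layout: {layout}")
    if offset < 0 or offset + n > len(memory):
        raise ValueError("slice write runs past the end of memory")
    memory[offset:offset + n] = data  # memory is updated in place


tile = [[1, 2], [3, 4]]
memory = [0] * 6
store_tile_slice(tile, 0, memory, 0)              # writes row 0
store_tile_slice(tile, 1, memory, 2, "vertical")  # writes column 1
```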

Attributes: 

Attribute | MLIR Type | Description
layout | ::mlir::arm_sme::TileSliceLayoutAttr | Layout of a tile slice

Enum cases:

  • horizontal (Horizontal)
  • vertical (Vertical)

Operands: 

Operand | Description
tile | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
tile_slice_index | index
base | memref of any type values
indices | index

arm_sme.tile_load (arm_sme::TileLoadOp) 

Tile load operation

Syntax:

operation ::= `arm_sme.tile_load` $base `[` $indices `]` (`,` $layout^)? attr-dict `:` type($base) `,` type($result)

Loads a 2D SME “virtual tile” from memory defined by a base and indices, with the shape defined by the 2D scalable vector type of the result tile. An optional tile slice layout attribute specifies whether the slices of the tile being loaded are horizontal (default) or vertical. The slice of memory must be contiguous. The memref must be either rank 1 or rank 2 with dynamic dimensions, since the operation is scalable, and the element type must be a scalar that matches the element type of the result.

Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).

%tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>

Example 2: Load an FP 32-bit element ZA tile with vertical layout from memory.

%tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>

Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.

%tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
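
Because the tile type is scalable, its runtime size depends on the streaming vector length (SVL). The Python sketch below assumes that vscale equals SVL divided by 128 bits, as on Arm SVE and SME, and models a horizontal tile load as one contiguous row read per tile slice. The names `num_tile_slices`, `tile_load`, `row_stride`, and `svl_bits` are hypothetical.

```python
def num_tile_slices(min_dim, svl_bits):
    """Runtime number of slices for a tile whose type is vector<[min_dim]x[min_dim]x...>."""
    if svl_bits < 128 or svl_bits % 128 != 0:
        raise ValueError("SVL must be a positive multiple of 128 bits")
    vscale = svl_bits // 128
    return min_dim * vscale


def tile_load(memory, base, row_stride, min_dim, svl_bits):
    """Load an n x n tile, reading one contiguous row of memory per slice."""
    n = num_tile_slices(min_dim, svl_bits)
    tile = []
    for slice_index in range(n):
        start = base + slice_index * row_stride
        row = memory[start:start + n]
        if len(row) != n:
            raise ValueError("tile read runs past the end of memory")
        tile.append(row)
    return tile


# A vector<[4]x[4]xi32> tile is 8 x 8 elements when SVL is 256 bits.
memory = list(range(64))
tile = tile_load(memory, base=0, row_stride=8, min_dim=4, svl_bits=256)
```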

Attributes: 

Attribute | MLIR Type | Description
layout | ::mlir::arm_sme::TileSliceLayoutAttr | Layout of a tile slice

Enum cases:

  • horizontal (Horizontal)
  • vertical (Vertical)

Operands: 

Operand | Description
base | memref of any type values
indices | index

Results: 

Result | Description
result | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values

arm_sme.tile_store (arm_sme::TileStoreOp) 

Tile store operation

Syntax:

operation ::= `arm_sme.tile_store` $valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict `:` type($base) `,` type($valueToStore)

Stores a 2D SME “virtual tile” to memory defined by a base and indices, with the shape defined by the 2D scalable vector type of the tile being stored. An optional tile slice layout attribute specifies whether the slices of the tile being stored are horizontal (default) or vertical. The slice of memory must be contiguous. The memref must be either rank 1 or rank 2 with dynamic dimensions, since the operation is scalable, and the element type must be a scalar that matches the element type of the tile being stored.

Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).

arm_sme.tile_store %tile, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>

Example 2: Store an FP 32-bit element ZA tile with vertical layout to memory.

arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>

Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.

arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>

Attributes: 

Attribute | MLIR Type | Description
layout | ::mlir::arm_sme::TileSliceLayoutAttr | Layout of a tile slice

Enum cases:

  • horizontal (Horizontal)
  • vertical (Vertical)

Operands: 

Operand | Description
valueToStore | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values
base | memref of any type values
indices | index

arm_sme.zero (arm_sme::ZeroOp) 

Initialize the two-dimensional ZA array with 0s

Syntax:

operation ::= `arm_sme.zero` attr-dict `:` type($res)

Initialise ZA with 0. This operation is a convenient wrapper for the SME zero intrinsic and instruction.

Example 1: Zero an 8-bit element ZA tile.

%0 = arm_sme.zero : vector<[16]x[16]xi8>

Example 2: Zero a 64-bit element ZA tile.

%0 = arm_sme.zero : vector<[2]x[2]xi64>
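
This page does not describe the 32-bit operand of the underlying `arm_sme.intr.zero` intrinsic. In the Arm architecture, the ZERO instruction takes an 8-bit mask with one bit per 64-bit element tile (ZA0.D to ZA7.D), and tiles of narrower elements are made up of several of those 64-bit tiles. Assuming the intrinsic operand carries that mask, the Python sketch below shows how a mask for a single tile could be derived; the `zero_mask` helper is hypothetical and is not taken from the dialect's lowering code.

```python
# Mask for tile 0 of each element width, one bit per 64-bit element tile.
BASE_MASK = {
    8: 0b11111111,   # ZA0.B covers ZA0.D to ZA7.D
    16: 0b01010101,  # ZA0.H covers ZA0.D, ZA2.D, ZA4.D, ZA6.D
    32: 0b00010001,  # ZA0.S covers ZA0.D and ZA4.D
    64: 0b00000001,  # ZA0.D
}


def zero_mask(element_bits, tile_id):
    """Return the 8-bit ZERO mask selecting tile `tile_id` of this element width."""
    if element_bits not in BASE_MASK:
        # 128-bit tiles are smaller than the 64-bit tiles the mask selects,
        # so a single 128-bit tile cannot be expressed this way.
        raise ValueError(f"unsupported element width: {element_bits}")
    if not 0 <= tile_id < element_bits // 8:
        raise ValueError(f"no tile {tile_id} for {element_bits}-bit elements")
    return BASE_MASK[element_bits] << tile_id


mask_za3_s = zero_mask(32, 3)  # selects ZA3.D and ZA7.D
```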

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Results: 

Result | Description
res | vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values

Attribute definition 

TileSliceLayoutAttr 

Layout of a tile slice

Syntax:

#arm_sme.layout<
  ::mlir::arm_sme::TileSliceLayout   # value
>

Enum cases:

  • horizontal (Horizontal)
  • vertical (Vertical)

Parameters: 

Parameter | C++ type | Description
value | ::mlir::arm_sme::TileSliceLayout | an enum of type TileSliceLayout