mlir.dialects._shard_ops_gen¶
Attributes¶

- _ods_ir

Classes¶

- _Dialect
- AllGatherOp: Concatenates all tensor slices from a device group defined by grid_axes along the tensor dimension gather_axis and replicates the result.
- AllReduceOp: Reduces the input tensor across all devices within the groups defined by grid_axes.
- AllSliceOp: Within each device group defined by grid_axes, slices the input tensor along slice_axis.
- AllToAllOp: Each participant logically splits its input along split_axis, scatters the pieces across the group defined by grid_axes, and concatenates what it receives along concat_axis.
- BroadcastOp: Copies the input tensor on root to all devices in each group defined by grid_axes.
- GatherOp: Concatenates all tensor slices from a device group defined by grid_axes on the root device.
- GetShardingOp: Returns the sharding of the given tensor as a Sharding.
- GridOp: Symbol operation that identifies a specific device grid.
- GridShapeOp
- NeighborsLinearIndicesOp: Returns the linear indices of the neighboring devices along the given split axes.
- ProcessLinearIndexOp: Returns the linear index of the calling device/process on the given grid.
- ProcessMultiIndexOp: Returns the multi-index of the calling device/process; used in the SPMD form of the IR.
- RecvOp: Receives a tensor from device source.
- ReduceOp: Reduces the input tensor across all devices within the groups defined by grid_axes, returning the result on the root device.
- ReduceScatterOp: Reduces the input tensor across all devices within the groups defined by grid_axes, then scatters the result along scatter_axis.
- ScatterOp: For each device group defined by grid_axes, splits the input on root along scatter_axis and distributes it across the group.
- SendOp: Sends the input tensor to device destination.
- ShardOp: The shard.shard operation specifies and guides the sharding behavior of a tensor value across a grid topology.
- ShardShapeOp: Computes the shape of the shard owned by the given device/process.
- ShardingOp: Specifies how a tensor is sharded and distributed across the process shard.
- ShiftOp: Within each device group defined by grid_axes, shifts input tensors along the device grid axis shift_axis.
- UpdateHaloOp: Updates halo regions of shards.
Functions¶
- all_gather
- all_reduce
- all_slice
- all_to_all
- broadcast
- gather
- get_sharding
- grid_shape
- neighbors_linear_indices
- process_linear_index
- process_multi_index
- recv
- reduce
- reduce_scatter
- scatter
- send
- shard
- shard_shape
- sharding
- shift
- update_halo
Module Contents¶
- mlir.dialects._shard_ops_gen._ods_ir¶
- class mlir.dialects._shard_ops_gen._Dialect(descriptor: object)¶
Bases: _ods_ir

- DIALECT_NAMESPACE = 'shard'¶
- class mlir.dialects._shard_ops_gen.AllGatherOp(result, grid, input, gather_axis, *, grid_axes=None, loc=None, ip=None)¶
Bases: _ods_ir

Concatenates all tensor slices from a device group defined by grid_axes along the tensor dimension gather_axis and replicates the result across all devices in the group.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1
  : tensor<2x2xi8> -> tensor<2x4xi8>
Input:
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
Result:
gather tensor
axis 1
------------>

+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
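A minimal NumPy sketch (illustrative only, not part of the generated API) that reproduces the example above for the group containing devices (0, 0) and (0, 1):

import numpy as np

# Local slices held by the two devices of one group (grid_axes = [1]).
dev_0_0 = np.array([[1, 2], [3, 4]], dtype=np.int8)
dev_0_1 = np.array([[5, 6], [7, 8]], dtype=np.int8)

# all_gather: concatenate the group's slices along gather_axis = 1 and
# replicate the result on every device of the group.
gathered = np.concatenate([dev_0_0, dev_0_1], axis=1)
assert gathered.tolist() == [[1, 2, 5, 6], [3, 4, 7, 8]]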
- OPERATION_NAME = 'shard.all_gather'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- gather_axis() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.all_gather(result, grid, input, gather_axis, *, grid_axes=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.AllReduceOp(result, grid, input, *, grid_axes=None, reduction=None, loc=None, ip=None)¶
Bases: _ods_ir

Reduces the input tensor across all devices within the groups defined by grid_axes, using the specified reduction method. The operation performs an element-wise reduction over the tensor slices from all devices in each group. Each device in a group receives a replicated copy of the reduction result. The accumulation element type is determined by the result type and does not need to match the input element type. Before performing the reduction, each input element is converted to the result element type.

Attributes:

reduction: Indicates the reduction method.

Example:
%1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max> : tensor<3x4xf32> -> tensor<3x4xf64>
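A minimal NumPy sketch (illustrative only) of the semantics: an element-wise max across the group's slices, accumulated in the wider result element type and replicated to every group member. The array contents below are made up for illustration:

import numpy as np

# Slices held by the devices of one group.
slices = [
    np.array([[1.0, 7.0], [3.0, 2.0]], dtype=np.float32),
    np.array([[4.0, 5.0], [9.0, 0.0]], dtype=np.float32),
]

# Convert to the result element type (f64 here), then reduce element-wise.
reduced = np.maximum.reduce([s.astype(np.float64) for s in slices])
# Every device in the group holds the same `reduced` tensor afterwards.
assert reduced.tolist() == [[4.0, 7.0], [9.0, 2.0]]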
- OPERATION_NAME = 'shard.all_reduce'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- reduction() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.all_reduce(result, grid, input, *, grid_axes=None, reduction=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.AllSliceOp(result, grid, input, slice_axis, *, grid_axes=None, loc=None, ip=None)¶
Bases: _ods_ir

Within each device group defined by grid_axes, slices the input tensor along the slice_axis dimension. It can be viewed as the inverse of an all-gather if the input data is replicated along slice_axis. Each process simply crops its local data to the slice corresponding to its in-group device index. Note: AllSliceOp does not involve any communication between devices, and devices within a group may not have replicated input data.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1
  : tensor<2x4xi8> -> tensor<2x2xi8>
Input:
+-------------+
|  1  2  5  6 | <- devices (0, 0) and (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 0) and (1, 1)
| 11 12 15 16 |
+-------------+
Result:
slice tensor
axis 1
------------>

                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
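A minimal NumPy sketch (illustrative only): each device crops its (replicated) local data to the piece matching its in-group index; no communication is involved:

import numpy as np

# Replicated input held by devices (0, 0) and (0, 1) of one group.
local = np.array([[1, 2, 5, 6], [3, 4, 7, 8]], dtype=np.int8)

group_size = 2        # number of devices along grid_axes = [1]
slice_axis = 1
pieces = np.split(local, group_size, axis=slice_axis)

# Device (0, 0) has in-group index 0, device (0, 1) has in-group index 1.
assert pieces[0].tolist() == [[1, 2], [3, 4]]
assert pieces[1].tolist() == [[5, 6], [7, 8]]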
- OPERATION_NAME = 'shard.all_slice'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- slice_axis() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.all_slice(result, grid, input, slice_axis, *, grid_axes=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.AllToAllOp(result, grid, input, split_axis, concat_axis, *, grid_axes=None, loc=None, ip=None)¶
Bases: _ods_ir

Each participant logically splits its input along split_axis, then scatters the resulting pieces across the group defined by grid_axes. After receiving data pieces from other participants' scatters, it concatenates them along concat_axis to produce the final result.

Example:

shard.grid @grid0(shape = 3)
...
%1 = shard.all_to_all %0 on @grid0 grid_axes = [0]
  split_axis = 0 concat_axis = 0
  : tensor<3x2xi8> -> tensor<3x2xi8>
Input:
 device  device  device
  (0)     (1)     (2)
+-------+-------+-------+  | split and concat along
| 11 12 | 21 22 | 31 32 |  | tensor axis 0
| 13 14 | 23 24 | 33 34 |  ↓
| 15 16 | 25 26 | 35 36 |
+-------+-------+-------+
Result:
 device  device  device
  (0)     (1)     (2)
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
+-------+-------+-------+
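A minimal NumPy sketch (illustrative only) reproducing the three-device example: every device splits its input along split_axis, piece j goes to device j, and the received pieces are concatenated along concat_axis:

import numpy as np

inputs = [  # per-device inputs for devices (0), (1), (2)
    np.array([[11, 12], [13, 14], [15, 16]], dtype=np.int8),
    np.array([[21, 22], [23, 24], [25, 26]], dtype=np.int8),
    np.array([[31, 32], [33, 34], [35, 36]], dtype=np.int8),
]
n, split_axis, concat_axis = len(inputs), 0, 0

# pieces[i][j] is the piece device i sends to device j.
pieces = [np.split(x, n, axis=split_axis) for x in inputs]
results = [np.concatenate([pieces[i][j] for i in range(n)], axis=concat_axis)
           for j in range(n)]

assert results[0].tolist() == [[11, 12], [21, 22], [31, 32]]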
- OPERATION_NAME = 'shard.all_to_all'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- split_axis() _ods_ir¶
- concat_axis() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.all_to_all(result, grid, input, split_axis, concat_axis, *, grid_axes=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.BroadcastOp(result, grid, input, root, root_dynamic, *, grid_axes=None, loc=None, ip=None)¶
Bases: _ods_ir

Copies the input tensor on root to all devices in each group defined by grid_axes. The root device is defined by its in-group multi-index. The contents of input tensors on non-root devices are ignored.

Example:

shard.grid @grid0(shape = 2x2)
%1 = shard.broadcast %0 on @grid0 grid_axes = [0] root = [0]
  : (tensor<2xi8>) -> tensor<2xi8>
Input:
                 +-------+-------+  | broadcast
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                 +-------+-------+  ↓
device (1, 0) -> |  *  * |  *  * | <- device (1, 1)
                 +-------+-------+

Output:
                 +-------+-------+
device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                 +-------+-------+
device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                 +-------+-------+
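A minimal Python sketch (illustrative only): within each group defined by grid_axes = [0] (the columns of the 2x2 grid), every device ends up with a copy of the root device's tensor:

import numpy as np

# Per-device inputs; the '*' entries on non-root devices are arbitrary and ignored.
inputs = {
    (0, 0): np.array([1, 2], dtype=np.int8),  # root of the column-0 group
    (0, 1): np.array([3, 4], dtype=np.int8),  # root of the column-1 group
    (1, 0): np.zeros(2, dtype=np.int8),
    (1, 1): np.zeros(2, dtype=np.int8),
}

# Groups run along grid axis 0; the in-group root index is [0].
outputs = {dev: inputs[(0, dev[1])].copy() for dev in inputs}
assert outputs[(1, 0)].tolist() == [1, 2] and outputs[(1, 1)].tolist() == [3, 4]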
- OPERATION_NAME = 'shard.broadcast'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- root_dynamic() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- root() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.broadcast(result, grid, input, root, root_dynamic, *, grid_axes=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.GatherOp(result, grid, input, gather_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None)¶
Bases: _ods_ir

Concatenates all tensor slices from a device group defined by grid_axes along the tensor dimension gather_axis and returns the resulting tensor on each root device. The result on all other (non-root) devices is undefined. The root device is defined by its in-group multi-index.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.gather %0 on @grid0 grid_axes = [1] gather_axis = 1 root = [1]
  : (tensor<2x2xi8>) -> tensor<2x4xi8>
Input:
gather tensor
axis 1
------------>

                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                 |  3  4 |  7  8 |
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                 | 11 12 | 15 16 |
                 +-------+-------+
Result:
+-------------+
|  1  2  5  6 | <- devices (0, 1)
|  3  4  7  8 |
+-------------+
|  9 10 13 14 | <- devices (1, 1)
| 11 12 15 16 |
+-------------+
Devices (0, 0) and (1, 0) have undefined result.

- OPERATION_NAME = 'shard.gather'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- root_dynamic() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- gather_axis() _ods_ir¶
- root() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.gather(result, grid, input, gather_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.GetShardingOp(source, *, results=None, loc=None, ip=None)¶
Bases: _ods_ir

This operation returns the sharding of the given tensor as a Sharding.
- OPERATION_NAME = 'shard.get_sharding'¶
- _ODS_REGIONS = (0, True)¶
- source() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.get_sharding(source, *, results=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.GridOp(sym_name, shape, *, loc=None, ip=None)¶
Bases: _ods_ir

The shard.grid operation is a symbol operation that identifies a specific grid. The operation has the following attributes:

#. sym_name: This attribute uniquely identifies the name of the grid. This name serves as a symbolic reference to the grid throughout the MLIR module, allowing for consistent referencing and easier debugging.

#. shape: This attribute represents the shape of the device grid. It uses the same notation as a tensor shape, and also allows for dynamic dimensions. This flexibility allows for dynamic device assignment or configurations where the exact number of devices might not be determined during compile time. For example 2x?x4.

Example:
// A device grid with 3 axes, the total device number is 4 * 8 * 12
// The dimension sizes are 4, 8, 12
shard.grid @grid0(shape = 4x8x12)

// A device grid with 2 axes, the total device number is unknown
// The first dimension size is 4 and the second is unknown
shard.grid @grid1(shape = 4x?)

// A device grid with 2 axes, the total device number is unknown
// The first dimension size is unknown and the second is 4
shard.grid @grid2(shape = ?x4)

// A device grid with 2 axes, the number of devices along both axes
// is unknown
shard.grid @grid3(shape = ?x?)
- OPERATION_NAME = 'shard.grid'¶
- _ODS_REGIONS = (0, True)¶
- sym_name() _ods_ir¶
- shape() _ods_ir¶
- class mlir.dialects._shard_ops_gen.GridShapeOp(result, grid, *, axes=None, loc=None, ip=None)¶
Bases: _ods_ir

- OPERATION_NAME = 'shard.grid_shape'¶
- _ODS_REGIONS = (0, True)¶
- grid() _ods_ir¶
- axes() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.grid_shape(result, grid, *, axes=None, loc=None, ip=None) _ods_ir | _ods_ir | GridShapeOp¶
- class mlir.dialects._shard_ops_gen.NeighborsLinearIndicesOp(grid, device, split_axes, *, results=None, loc=None, ip=None)¶
Bases: _ods_ir

Example:

shard.grid @grid0(shape = 10x20x30)
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%idx = shard.neighbors_linear_indices on @grid[%c1, %c2, %c3] split_axes = [1] : index
The above returns two indices, 633 and 693, which correspond to the index of the previous process (1, 1, 3) and the next process (1, 3, 3) along the split axis 1.

A negative value is returned if there is no neighbor in the respective direction along the given split_axes.
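A small Python sketch (illustrative only) that reproduces the two linear indices from the example, using row-major linearization of the 10x20x30 grid:

def linearize(shape, idx):
    """Row-major linear index of a multi-index on a device grid."""
    linear = 0
    for extent, i in zip(shape, idx):
        linear = linear * extent + i
    return linear

shape, device, axis = (10, 20, 30), (1, 2, 3), 1
prev = list(device); prev[axis] -= 1   # (1, 1, 3)
nxt = list(device); nxt[axis] += 1     # (1, 3, 3)
assert linearize(shape, prev) == 633
assert linearize(shape, nxt) == 693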
- OPERATION_NAME = 'shard.neighbors_linear_indices'¶
- _ODS_REGIONS = (0, True)¶
- device() _ods_ir¶
- grid() _ods_ir¶
- split_axes() _ods_ir¶
- neighbor_down() _ods_ir¶
- neighbor_up() _ods_ir¶
- mlir.dialects._shard_ops_gen.neighbors_linear_indices(grid, device, split_axes, *, results=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ProcessLinearIndexOp(grid, *, results=None, loc=None, ip=None)¶
Bases: _ods_ir

Example:
%idx = shard.process_linear_index on @grid : index
If @grid has shape (10, 20, 30), a device with multi index (1, 2, 3) will have linear index 3 + 30*2 + 20*30*1.
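A small Python sketch (illustrative only) of the row-major linearization described above:

def linearize(shape, idx):
    linear = 0
    for extent, i in zip(shape, idx):
        linear = linear * extent + i
    return linear

# 3 + 30*2 + 20*30*1 == 663 for multi index (1, 2, 3) on a 10x20x30 grid.
assert linearize((10, 20, 30), (1, 2, 3)) == 663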
- OPERATION_NAME = 'shard.process_linear_index'¶
- _ODS_REGIONS = (0, True)¶
- grid() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.process_linear_index(grid, *, results=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ProcessMultiIndexOp(result, grid, *, axes=None, loc=None, ip=None)¶
Bases: _ods_ir

It is used in the SPMD format of IR. The axes must be non-negative and less than the total number of grid axes. If the axes are empty, the index is returned along all axes.

- OPERATION_NAME = 'shard.process_multi_index'¶
- _ODS_REGIONS = (0, True)¶
- grid() _ods_ir¶
- axes() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.process_multi_index(result, grid, *, axes=None, loc=None, ip=None) _ods_ir | _ods_ir | ProcessMultiIndexOp¶
- class mlir.dialects._shard_ops_gen.RecvOp(result, grid, input, source_dynamic, *, grid_axes=None, source=None, loc=None, ip=None)¶
Bases: _ods_ir

Receive a tensor from device source, which is defined by its in-group multi-index. The groups are defined by grid_axes. The content of the input tensor is ignored.

- OPERATION_NAME = 'shard.recv'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- source_dynamic() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- source() _ods_ir | None¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.recv(result, grid, input, source_dynamic, *, grid_axes=None, source=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ReduceOp(result, grid, input, root, root_dynamic, *, grid_axes=None, reduction=None, loc=None, ip=None)¶
Bases: _ods_ir

Reduces the input tensor across all devices within the groups defined by grid_axes, using the specified reduction method. The operation performs an element-wise reduction over the tensor slices from all devices in each group. The reduction result will be returned on the root device of each group. It is undefined on all other (non-root) devices. The root device is defined by its in-group multi-index. The accumulation element type is determined by the result type and does not need to match the input element type. Before performing the reduction, each input element is converted to the result element type.

Attributes:

reduction: Indicates the reduction method.

Example:
%1 = shard.reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max> root = [2, 3] : (tensor<3x4xf32>) -> tensor<3x4xf64>
- OPERATION_NAME = 'shard.reduce'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- root_dynamic() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- reduction() _ods_ir¶
- root() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.reduce(result, grid, input, root, root_dynamic, *, grid_axes=None, reduction=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ReduceScatterOp(result, grid, input, scatter_axis, *, grid_axes=None, reduction=None, loc=None, ip=None)¶
Bases: _ods_ir

Reduces the input tensor across all devices within the groups defined by grid_axes using the specified reduction method. The reduction is performed element-wise across the tensor pieces from all devices in the group. After reduction, the reduction result is scattered (split and distributed) across the device group along scatter_axis.

Example:

shard.grid @grid0(shape = 2x2)
...
%1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1]
  reduction = <max> scatter_axis = 0
  : tensor<2x2xf32> -> tensor<1x2xf64>
Input:
                     device (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                 |  3  4 |  7  8 |  ↓
                 +-------+-------+
device (1, 0) -> |  9 10 | 13 14 |
                 | 11 12 | 15 16 |
                 +-------+-------+
                             ↑
                     device (1, 1)

Result:
+-------+
|  5  6 | <- devices (0, 0)
+-------+
|  7  8 | <- devices (0, 1)
+-------+
| 13 14 | <- devices (1, 0)
+-------+
| 15 16 | <- devices (1, 1)
+-------+
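A minimal NumPy sketch (illustrative only) reproducing the group containing devices (0, 0) and (0, 1): the slices are max-reduced element-wise, accumulated in the wider result type, and then split along scatter_axis = 0:

import numpy as np

# Slices held by devices (0, 0) and (0, 1) of one group (grid_axes = [1]).
slices = [
    np.array([[1, 2], [3, 4]], dtype=np.float32),
    np.array([[5, 6], [7, 8]], dtype=np.float32),
]

reduced = np.maximum.reduce([s.astype(np.float64) for s in slices])
pieces = np.split(reduced, len(slices), axis=0)  # scatter along axis 0

assert pieces[0].tolist() == [[5.0, 6.0]]   # held by device (0, 0)
assert pieces[1].tolist() == [[7.0, 8.0]]   # held by device (0, 1)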
- OPERATION_NAME = 'shard.reduce_scatter'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- reduction() _ods_ir¶
- scatter_axis() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.reduce_scatter(result, grid, input, scatter_axis, *, grid_axes=None, reduction=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ScatterOp(result, grid, input, scatter_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None)¶
Bases: _ods_ir

For each device group defined by grid_axes, the input tensor on the root device is split along axis scatter_axis and distributed across the group. The content of the input on all other (non-root) devices is ignored. The root device is defined by its in-group multi-index.

Example:

shard.grid @grid0(shape = 2x2)
%1 = shard.scatter %0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1]
  : (tensor<2x2xi8>) -> tensor<1x2xi8>
Input:
                     device (0, 1)
                             ↓
                 +-------+-------+  | scatter tensor
device (0, 0) -> |  *  * |  *  * |  | axis 0
                 |  *  * |  *  * |  ↓
                 +-------+-------+
device (1, 0) -> |  1  2 |  5  6 |
                 |  3  4 |  7  8 |
                 +-------+-------+
                             ↑
                     device (1, 1)

Result:

                     device (0, 1)
                             ↓
                 +-------+-------+
device (0, 0) -> |  1  2 |  5  6 |
                 +-------+-------+
device (1, 0) -> |  3  4 |  7  8 |
                 +-------+-------+
                             ↑
                     device (1, 1)
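A minimal NumPy sketch (illustrative only) of one group (the first column, grid_axes = [0], root in-group index [1]): the root's tensor is split along scatter_axis and piece i goes to the group member with in-group index i:

import numpy as np

# Input on the root device (1, 0) of the column-0 group; inputs on
# non-root devices are ignored.
root_input = np.array([[1, 2], [3, 4]], dtype=np.int8)

group_size, scatter_axis = 2, 0
pieces = np.split(root_input, group_size, axis=scatter_axis)

assert pieces[0].tolist() == [[1, 2]]  # received by device (0, 0)
assert pieces[1].tolist() == [[3, 4]]  # received by device (1, 0)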
- OPERATION_NAME = 'shard.scatter'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- root_dynamic() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- scatter_axis() _ods_ir¶
- root() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.scatter(result, grid, input, scatter_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.SendOp(result, grid, input, destination, destination_dynamic, *, grid_axes=None, loc=None, ip=None)¶
Bases: _ods_ir

Send the input tensor to device destination, which is defined by its in-group multi-index. The groups are defined by grid_axes.

- OPERATION_NAME = 'shard.send'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- destination_dynamic() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- destination() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.send(result, grid, input, destination, destination_dynamic, *, grid_axes=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ShardOp(src, sharding, *, annotate_for_users=None, results=None, loc=None, ip=None)¶
Bases: _ods_ir

The shard.shard operation is designed to specify and guide the sharding behavior of a tensor value across a grid topology. This operation has two operands and two optional attributes:

#. input: This operand represents the tensor value that needs to be annotated for sharding.

#. sharding: This attribute is of type ShardingType, which is the core data structure to represent the distribution of a tensor on a shard. It is typically defined by a shard.sharding operation.

#. annotate_for_users: A unit attribute addressing the scenario when a tensor's sharding annotation differs based on its context of use (either as a result or an operand). If specified, the sharding pertains to specific users of the tensor value, indicating how it should be considered when used as an operand in subsequent operations. If not, the sharding applies to the operation that defines the tensor value.

Example:
func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
  ...
}

func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
  %1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
  ...
}

// The first shard.shard op applies to %arg0, the second shard.shard op
// applies for the operand of op0, the third shard.shard op applies for the
// operand of op2
func.func @both_result_and_multi_operands_annotated(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
  %sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
  %2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  "op0"(%1) : ...
  "op1"(%2) : ...
  ...
}

The following usages are undefined:
func.func @annotate_on_same_result_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_result_same_value_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
  %1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
  ...
}

func.func @annotate_on_same_operand_with_different_sharding(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
  ...
}

func.func @result_annotated_after_operand(
    %arg0 : tensor<4x8xf32>) -> () {
  %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
  %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
  %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
  %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
  ...
}

- OPERATION_NAME = 'shard.shard'¶
- _ODS_REGIONS = (0, True)¶
- src() _ods_ir¶
- sharding() _ods_ir¶
- annotate_for_users() bool¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.shard(src, sharding, *, annotate_for_users=None, results=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ShardShapeOp(result, dims, dims_dynamic, sharding, device, device_dynamic, *, loc=None, ip=None)¶
Bases: _ods_ir

The device/process id is a multi-index of the device/process in the shard. This operation might be used during partition when the shard shape depends on (non-constant) values used in shard.sharding.

- OPERATION_NAME = 'shard.shard_shape'¶
- _ODS_OPERAND_SEGMENTS¶
- _ODS_REGIONS = (0, True)¶
- dims_dynamic() _ods_ir¶
- sharding() _ods_ir¶
- device_dynamic() _ods_ir¶
- dims() _ods_ir¶
- device() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.shard_shape(result, dims, dims_dynamic, sharding, device, device_dynamic, *, loc=None, ip=None) _ods_ir | _ods_ir | ShardShapeOp¶
- class mlir.dialects._shard_ops_gen.ShardingOp(grid, split_axes, dynamic_sharded_dims_offsets, dynamic_halo_sizes, *, static_sharded_dims_offsets=None, static_halo_sizes=None, results=None, loc=None, ip=None)¶
Bases: _ods_ir

The Sharding specifies how a tensor is sharded and distributed across the process shard. It is typically used in a shard.shard operation. The operation has the following attributes and operands:

#. grid: this attribute is a FlatSymbolRefAttr that refers to the device grid where the distributed tensor is placed. The symbol must resolve to a shard.grid operation.

#. split_axes: is an array composed of int64_t sub-arrays. The outer array's maximum size is the rank of the related tensor. For the i-th sub-array, if its value is [x, y], it indicates that the tensor's i-th dimension is split along the x and y axes of the device grid.

#. [Optional] Sizes of halos to be added for each sharded tensor dimension. halo_sizes is provided as a flattened 1d array of i64s, 2 values for each sharded dimension. halo_sizes = [1, 2] means that the first sharded dimension gets an additional halo of size 1 at the start of the first dimension and a halo of size 2 at its end. halo_sizes = [1, 2, 2, 3] defines halos for the first 2 sharded dimensions, e.g. the first sharded dimension gets [1, 2] halos and the second gets [2, 3] halos. ? indicates dynamic halo sizes.

#. [Optional] Offsets for each shard and sharded tensor dimension. sharded_dims_offsets is provided as a flattened 1d array of i64s. For each sharded tensor dimension the offsets (starting index) of all shards in that dimension and an additional value for the end of the last shard are provided. For a 1d sharding this means that position i has the exclusive prefix sum for shard i, and since only contiguous sharding is supported, its inclusive prefix sum is at position i+1.

Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions being sharded,
sharded_dims_offsets = [0, 24, 32, 0, 20, 32] means that the first device of the device-grid will get a shard of shape 24x20x32 and the second device will get a shard of shape 8x12x32. ? indicates dynamic shard dimensions.

halo_sizes and sharded_dims_offsets are mutually exclusive.

Examples:
shard.grid @grid0(shape = 2x2x4)
shard.grid @grid1d_4(shape = 4)

// The tensor is fully replicated on @grid0.
// Currently, there must be at least one sub-array present in axes, even
// if it's empty. Otherwise, a parsing error will occur.
%sharding0 = shard.sharding @grid0 split_axes = [[]]

// The tensor is sharded on the first dimension along axis 0 of @grid0
%sharding1 = shard.sharding @grid0 split_axes = [[0]]

// Could be used for a shard.shard op
%sharded0 = shard.shard %arg0 to %sharding3 : tensor<4x8xf32>

// The tensor is sharded on its first dimension along axis 0 of @grid0
// and it has halo-sizes of 1 and 2 on the sharded dim.
%halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
%sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>

// The tensor is sharded on its second dimension along axis 0 of @grid1d_4
// and it has pre-defined shard sizes. The shards of the devices will have
// the following shapes: [4x2, 4x3, 4x4, 4x5]
%sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
%sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>
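A small Python sketch (illustrative only) of how per-shard sizes follow from sharded_dims_offsets, using the 32x32x32 example above:

# Flattened offsets: [0, 24, 32] for tensor dim 0, [0, 20, 32] for tensor dim 1.
offsets_per_dim = [[0, 24, 32], [0, 20, 32]]

# Consecutive differences give the shard sizes in each sharded dimension.
sizes_per_dim = [[b - a for a, b in zip(offs, offs[1:])] for offs in offsets_per_dim]
assert sizes_per_dim == [[24, 8], [20, 12]]
# First device: 24x20x32 shard, second device: 8x12x32 shard
# (the last dimension, 32, is not sharded).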
- OPERATION_NAME = 'shard.sharding'¶
- _ODS_OPERAND_SEGMENTS¶
- _ODS_REGIONS = (0, True)¶
- dynamic_sharded_dims_offsets() _ods_ir¶
- dynamic_halo_sizes() _ods_ir¶
- grid() _ods_ir¶
- split_axes() _ods_ir¶
- static_sharded_dims_offsets() _ods_ir¶
- static_halo_sizes() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.sharding(grid, split_axes, dynamic_sharded_dims_offsets, dynamic_halo_sizes, *, static_sharded_dims_offsets=None, static_halo_sizes=None, results=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.ShiftOp(result, grid, input, shift_axis, offset, *, grid_axes=None, rotate=None, loc=None, ip=None)¶
Bases: _ods_ir

Within each device group defined by grid_axes, shifts input tensors along the device grid's axis shift_axis by the specified offset. The shift_axis must be one of the grid_axes. If the rotate attribute is set, the shift is circular. That is, the offset wraps around according to the group size along shift_axis. Otherwise, the results on devices without a corresponding source are undefined.

Example:

shard.grid @grid0(shape = 2x4)
%1 = shard.shift on @grid0 grid_axes = [1] shift_axis = 1 offset = 2 rotate
  : tensor<2xi8> -> tensor<2xi8>
Input:
grid axis 1
----------->

+----+----+----+----+
|  1 |  2 |  3 |  4 |
+----+----+----+----+
|  5 |  6 |  7 |  8 |
+----+----+----+----+
Result:
+----+----+----+----+
|  3 |  4 |  1 |  2 |
+----+----+----+----+
|  7 |  8 |  5 |  6 |
+----+----+----+----+
- OPERATION_NAME = 'shard.shift'¶
- _ODS_REGIONS = (0, True)¶
- input() _ods_ir¶
- grid() _ods_ir¶
- grid_axes() _ods_ir¶
- shift_axis() _ods_ir¶
- offset() _ods_ir¶
- rotate() bool¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.shift(result, grid, input, shift_axis, offset, *, grid_axes=None, rotate=None, loc=None, ip=None) _ods_ir¶
- class mlir.dialects._shard_ops_gen.UpdateHaloOp(result, destination, grid, split_axes, halo_sizes, *, static_halo_sizes=None, loc=None, ip=None)¶
Bases: _ods_ir

This operation updates halo regions of shards, e.g. if their sharding specified halos and the actual tensor/memref data might have changed on the remote devices. Changes might be caused by mutating operations and/or if the new halo regions are larger than the existing ones.
Destination is supposed to be initialized with the local data (not halos).
Assumes all devices hold tensors with same-sized halo data as specified by source_halo_sizes/static_source_halo_sizes and destination_halo_sizes/static_destination_halo_sizes in source shard and destination/result shard.

split_axes specifies for each tensor axis along which grid axes its halo data is updated.
- _ODS_REGIONS = (0, True)¶
- destination() _ods_ir¶
- halo_sizes() _ods_ir¶
- grid() _ods_ir¶
- split_axes() _ods_ir¶
- static_halo_sizes() _ods_ir¶
- result() _ods_ir¶
Shortcut to get an op result if it has only one (throws an error otherwise).
- mlir.dialects._shard_ops_gen.update_halo(result, destination, grid, split_axes, halo_sizes, *, static_halo_sizes=None, loc=None, ip=None) _ods_ir¶