mlir.dialects._shard_ops_gen
============================

.. py:module:: mlir.dialects._shard_ops_gen


Attributes
----------

.. autoapisummary::

   mlir.dialects._shard_ops_gen._ods_ir


Classes
-------

.. autoapisummary::

   mlir.dialects._shard_ops_gen._Dialect
   mlir.dialects._shard_ops_gen.AllGatherOp
   mlir.dialects._shard_ops_gen.AllReduceOp
   mlir.dialects._shard_ops_gen.AllSliceOp
   mlir.dialects._shard_ops_gen.AllToAllOp
   mlir.dialects._shard_ops_gen.BroadcastOp
   mlir.dialects._shard_ops_gen.GatherOp
   mlir.dialects._shard_ops_gen.GetShardingOp
   mlir.dialects._shard_ops_gen.GridOp
   mlir.dialects._shard_ops_gen.GridShapeOp
   mlir.dialects._shard_ops_gen.NeighborsLinearIndicesOp
   mlir.dialects._shard_ops_gen.ProcessLinearIndexOp
   mlir.dialects._shard_ops_gen.ProcessMultiIndexOp
   mlir.dialects._shard_ops_gen.RecvOp
   mlir.dialects._shard_ops_gen.ReduceOp
   mlir.dialects._shard_ops_gen.ReduceScatterOp
   mlir.dialects._shard_ops_gen.ScatterOp
   mlir.dialects._shard_ops_gen.SendOp
   mlir.dialects._shard_ops_gen.ShardOp
   mlir.dialects._shard_ops_gen.ShardShapeOp
   mlir.dialects._shard_ops_gen.ShardingOp
   mlir.dialects._shard_ops_gen.ShiftOp
   mlir.dialects._shard_ops_gen.UpdateHaloOp


Functions
---------
.. autoapisummary::

   mlir.dialects._shard_ops_gen.all_gather
   mlir.dialects._shard_ops_gen.all_reduce
   mlir.dialects._shard_ops_gen.all_slice
   mlir.dialects._shard_ops_gen.all_to_all
   mlir.dialects._shard_ops_gen.broadcast
   mlir.dialects._shard_ops_gen.gather
   mlir.dialects._shard_ops_gen.get_sharding
   mlir.dialects._shard_ops_gen.grid
   mlir.dialects._shard_ops_gen.grid_shape
   mlir.dialects._shard_ops_gen.neighbors_linear_indices
   mlir.dialects._shard_ops_gen.process_linear_index
   mlir.dialects._shard_ops_gen.process_multi_index
   mlir.dialects._shard_ops_gen.recv
   mlir.dialects._shard_ops_gen.reduce
   mlir.dialects._shard_ops_gen.reduce_scatter
   mlir.dialects._shard_ops_gen.scatter
   mlir.dialects._shard_ops_gen.send
   mlir.dialects._shard_ops_gen.shard
   mlir.dialects._shard_ops_gen.shard_shape
   mlir.dialects._shard_ops_gen.sharding
   mlir.dialects._shard_ops_gen.shift
   mlir.dialects._shard_ops_gen.update_halo


Module Contents
---------------

.. py:data:: _ods_ir

.. py:class:: _Dialect(descriptor: object)

   Bases: :py:obj:`_ods_ir`

   .. py:attribute:: DIALECT_NAMESPACE
      :value: 'shard'

.. py:class:: AllGatherOp(result, grid, input, gather_axis, *, grid_axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Concatenates all tensor slices from a device group defined by ``grid_axes``
   along the tensor dimension ``gather_axis`` and replicates the result across
   all devices in the group.

   Example:

   .. code:: mlir

      shard.grid @grid0(shape = 2x2)
      ...
      %1 = shard.all_gather %0 on @grid0 grid_axes = [1] gather_axis = 1 : tensor<2x2xi8> -> tensor<2x4xi8>

   Input:

   .. code::

                       +-------+-------+
      device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                       |  3  4 |  7  8 |
                       +-------+-------+
      device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                       | 11 12 | 15 16 |
                       +-------+-------+

   Result:

   .. code::

      gather tensor
      axis 1
      ------------>
      +-------------+
      |  1  2  5  6 | <- devices (0, 0) and (0, 1)
      |  3  4  7  8 |
      +-------------+
      |  9 10 13 14 | <- devices (1, 0) and (1, 1)
      | 11 12 15 16 |
      +-------------+
   .. py:attribute:: OPERATION_NAME
      :value: 'shard.all_gather'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: gather_axis() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: all_gather(result, grid, input, gather_axis, *, grid_axes=None, loc=None, ip=None) -> _ods_ir

.. py:class:: AllReduceOp(result, grid, input, *, grid_axes=None, reduction=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Reduces the input tensor across all devices within the groups defined by
   ``grid_axes``, using the specified reduction method. The operation performs
   an element-wise reduction over the tensor slices from all devices in each
   group. Each device in a group receives a replicated copy of the reduction
   result. The accumulation element type is determined by the result type and
   does not need to match the input element type. Before the reduction is
   performed, each input element is converted to the result element type.

   Attributes:

   ``reduction``: Indicates the reduction method.

   Example:

   .. code::

      %1 = shard.all_reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max> : tensor<3x4xf32> -> tensor<3x4xf64>

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.all_reduce'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: reduction() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: all_reduce(result, grid, input, *, grid_axes=None, reduction=None, loc=None, ip=None) -> _ods_ir
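The data movement of ``shard.all_gather`` and ``shard.all_reduce`` on the 2x2 grid from the examples above can be simulated in plain Python. This is a hypothetical sketch and does not use the MLIR bindings. Each device is modeled as a dictionary entry keyed by its multi-index. It assumes that in-group order follows the coordinates on ``grid_axes`` in the listed order, and it uses ``max`` as the ``all_reduce`` reduction. The helper names ``groups``, ``concat``, ``all_gather_sim`` and ``all_reduce_sim`` are illustrative only.

```python
from itertools import product

# Hypothetical model: a dict maps each device multi-index (a tuple) to its
# local 2-D tensor, stored as a list of rows. Not the MLIR bindings API.


def groups(grid_shape, grid_axes):
    """Partition devices into groups that differ only along grid_axes."""
    buckets = {}
    for dev in product(*(range(n) for n in grid_shape)):
        # Devices that agree on every axis outside grid_axes share a group.
        key = tuple(c for i, c in enumerate(dev) if i not in grid_axes)
        buckets.setdefault(key, []).append(dev)
    # Order each group by its in-group multi-index (coordinates on grid_axes).
    return [sorted(g, key=lambda d: tuple(d[a] for a in grid_axes))
            for g in buckets.values()]


def concat(tensors, axis):
    """Concatenate 2-D list tensors along tensor axis 0 or 1."""
    if axis == 0:
        return [list(row) for t in tensors for row in t]
    return [[x for t in tensors for x in t[r]] for r in range(len(tensors[0]))]


def all_gather_sim(data, grid_shape, grid_axes, gather_axis):
    out = {}
    for group in groups(grid_shape, grid_axes):
        gathered = concat([data[d] for d in group], gather_axis)
        for d in group:
            out[d] = [list(row) for row in gathered]  # replicated copy
    return out


def all_reduce_sim(data, grid_shape, grid_axes, op=max, convert=float):
    out = {}
    for group in groups(grid_shape, grid_axes):
        tensors = [data[d] for d in group]
        rows, cols = len(tensors[0]), len(tensors[0][0])
        # Convert each element to the result type first, then reduce.
        reduced = [[op(convert(t[r][c]) for t in tensors) for c in range(cols)]
                   for r in range(rows)]
        for d in group:
            out[d] = [list(row) for row in reduced]
    return out


# Input from the all_gather diagram on the 2x2 grid @grid0.
data = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
gathered = all_gather_sim(data, (2, 2), grid_axes=[1], gather_axis=1)
print(gathered[(0, 0)])  # [[1, 2, 5, 6], [3, 4, 7, 8]]

reduced = all_reduce_sim(data, (2, 2), grid_axes=[1, 0])
print(reduced[(1, 1)])  # [[13.0, 14.0], [15.0, 16.0]]
```

Python ``float`` is a 64-bit double, so the ``convert`` step mirrors the ``f32 -> f64`` accumulation in the ``all_reduce`` example.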
.. py:class:: AllSliceOp(result, grid, input, slice_axis, *, grid_axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Within each device group defined by ``grid_axes``, slices the input tensor
   along the ``slice_axis`` dimension. It can be viewed as the inverse of an
   all-gather if the input data is replicated along the ``slice_axis``. Each
   process simply crops its local data to the slice corresponding to its
   in-group device index.

   Note: ``AllSliceOp`` does not involve any communication between devices,
   and the devices within a group are not required to hold replicated input
   data.

   Example:

   .. code:: mlir

      shard.grid @grid0(shape = 2x2)
      ...
      %1 = shard.all_slice %0 on @grid0 grid_axes = [1] slice_axis = 1 : tensor<2x4xi8> -> tensor<2x2xi8>

   Input:

   .. code::

      +-------------+
      |  1  2  5  6 | <- devices (0, 0) and (0, 1)
      |  3  4  7  8 |
      +-------------+
      |  9 10 13 14 | <- devices (1, 0) and (1, 1)
      | 11 12 15 16 |
      +-------------+

   Result:

   .. code::

                       slice tensor
                       axis 1
                       ------------>
                       +-------+-------+
      device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                       |  3  4 |  7  8 |
                       +-------+-------+
      device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                       | 11 12 | 15 16 |
                       +-------+-------+

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.all_slice'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: slice_axis() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: all_slice(result, grid, input, slice_axis, *, grid_axes=None, loc=None, ip=None) -> _ods_ir

.. py:class:: AllToAllOp(result, grid, input, split_axis, concat_axis, *, grid_axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Each participant logically splits its input along ``split_axis``, then
   scatters the resulting pieces across the group defined by ``grid_axes``.
   After receiving data pieces from other participants' scatters, it
   concatenates them along ``concat_axis`` to produce the final result.

   Example:

   .. code::

      shard.grid @grid0(shape = 3)
      ...
      %1 = shard.all_to_all %0 on @grid0 grid_axes = [0] split_axis = 0 concat_axis = 0 : tensor<3x2xi8> -> tensor<3x2xi8>

   Input:

   .. code::

       device  device  device
       (0)     (1)     (2)
      +-------+-------+-------+  | split and concat along
      | 11 12 | 21 22 | 31 32 |  | tensor axis 0
      | 13 14 | 23 24 | 33 34 |  ↓
      | 15 16 | 25 26 | 35 36 |
      +-------+-------+-------+

   Result:

   .. code::

       device  device  device
       (0)     (1)     (2)
      +-------+-------+-------+
      | 11 12 | 13 14 | 15 16 |
      | 21 22 | 23 24 | 25 26 |
      | 31 32 | 33 34 | 35 36 |
      +-------+-------+-------+

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.all_to_all'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: split_axis() -> _ods_ir

   .. py:method:: concat_axis() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: all_to_all(result, grid, input, split_axis, concat_axis, *, grid_axes=None, loc=None, ip=None) -> _ods_ir

.. py:class:: BroadcastOp(result, grid, input, root, root_dynamic, *, grid_axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Copies the input tensor on ``root`` to all devices in each group defined by
   ``grid_axes``. The ``root`` device is defined by its in-group multi-index.
   The contents of input tensors on non-root devices are ignored.

   Example:

   .. code::

      shard.grid @grid0(shape = 2x2)

      %1 = shard.broadcast %0 on @grid0 grid_axes = [0] root = [0] : (tensor<2xi8>) -> tensor<2xi8>

   Input:

   .. code::

                       +-------+-------+                   | broadcast
      device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)  | along axis 0
                       +-------+-------+                   ↓
      device (1, 0) -> |  *  * |  *  * | <- device (1, 1)
                       +-------+-------+

   Output:
   .. code::

                       +-------+-------+
      device (0, 0) -> |  1  2 |  3  4 | <- device (0, 1)
                       +-------+-------+
      device (1, 0) -> |  1  2 |  3  4 | <- device (1, 1)
                       +-------+-------+

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.broadcast'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: root_dynamic() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: root() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: broadcast(result, grid, input, root, root_dynamic, *, grid_axes=None, loc=None, ip=None) -> _ods_ir

.. py:class:: GatherOp(result, grid, input, gather_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Concatenates all tensor slices from a device group defined by ``grid_axes``
   along the tensor dimension ``gather_axis`` and returns the resulting tensor
   on each ``root`` device. The result on all other (non-root) devices is
   undefined. The ``root`` device is defined by its in-group multi-index.

   Example:

   .. code:: mlir

      shard.grid @grid0(shape = 2x2)
      ...
      %1 = shard.gather %0 on @grid0 grid_axes = [1] gather_axis = 1 root = [1] : (tensor<2x2xi8>) -> tensor<2x4xi8>

   Input:

   .. code::

                       gather tensor
                       axis 1
                       ------------>
                       +-------+-------+
      device (0, 0) -> |  1  2 |  5  6 | <- device (0, 1)
                       |  3  4 |  7  8 |
                       +-------+-------+
      device (1, 0) -> |  9 10 | 13 14 | <- device (1, 1)
                       | 11 12 | 15 16 |
                       +-------+-------+

   Result:

   .. code::

      +-------------+
      |  1  2  5  6 | <- device (0, 1)
      |  3  4  7  8 |
      +-------------+
      |  9 10 13 14 | <- device (1, 1)
      | 11 12 15 16 |
      +-------------+

   Devices ``(0, 0)`` and ``(1, 0)`` have an undefined result.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.gather'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: root_dynamic() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir
   .. py:method:: gather_axis() -> _ods_ir

   .. py:method:: root() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: gather(result, grid, input, gather_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None) -> _ods_ir

.. py:class:: GetShardingOp(source, *, results=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Returns the sharding of the given tensor as a value of type ``!shard.sharding``.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.get_sharding'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: source() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: get_sharding(source, *, results=None, loc=None, ip=None) -> _ods_ir

.. py:class:: GridOp(sym_name, shape, *, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   The ``shard.grid`` operation is a symbol operation that identifies a
   specific grid. The operation has two attributes:

   #. ``sym_name``: This attribute uniquely identifies the name of the grid.
      This name serves as a symbolic reference to the grid throughout the MLIR
      module, allowing for consistent referencing and easier debugging.
   #. ``shape``: This attribute represents the shape of the device grid. It
      uses the same notation as a tensor shape, including dynamic dimensions.
      This flexibility allows for dynamic device assignment or configurations
      where the exact number of devices might not be known at compile time,
      for example ``2x?x4``.

   Example:

   .. code::

      // A device grid with 3 axes, the total device number is 4 * 8 * 12
      // The dimension sizes are 4, 8, 12
      shard.grid @grid0(shape = 4x8x12)

      // A device grid with 2 axes, the total device number is unknown
      // The first dimension size is 4 and the second is unknown
      shard.grid @grid1(shape = 4x?)
      // A device grid with 2 axes, the total device number is unknown
      // The first dimension size is unknown and the second is 4
      shard.grid @grid2(shape = ?x4)

      // A device grid with 2 axes, the number of devices along both axes
      // is unknown
      shard.grid @grid3(shape = ?x?)

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.grid'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: sym_name() -> _ods_ir

   .. py:method:: shape() -> _ods_ir

.. py:function:: grid(sym_name, shape, *, loc=None, ip=None) -> GridOp

.. py:class:: GridShapeOp(result, grid, *, axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.grid_shape'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: grid() -> _ods_ir

   .. py:method:: axes() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: grid_shape(result, grid, *, axes=None, loc=None, ip=None) -> Union[_ods_ir, _ods_ir, GridShapeOp]

.. py:class:: NeighborsLinearIndicesOp(grid, device, split_axes, *, results=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Example:

   .. code::

      shard.grid @grid0(shape = 10x20x30)
      %c1 = arith.constant 1 : index
      %c2 = arith.constant 2 : index
      %c3 = arith.constant 3 : index
      %idx = shard.neighbors_linear_indices on @grid0[%c1, %c2, %c3] split_axes = [1] : index

   The above returns two indices, ``633`` and ``693``. These are the linear
   indices of the previous process ``(1, 1, 3)`` and the next process
   ``(1, 3, 3)`` along split axis ``1``.

   A negative value is returned if there is no neighbor in the respective
   direction along the given ``split_axes``.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.neighbors_linear_indices'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: device() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: split_axes() -> _ods_ir

   .. py:method:: neighbor_down() -> _ods_ir

   .. py:method:: neighbor_up() -> _ods_ir
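The index arithmetic behind ``shard.neighbors_linear_indices`` can be checked in plain Python. The same row-major linearization is used by ``shard.process_linear_index``. The helpers below are hypothetical illustrations, not binding APIs. They assume a static grid shape and a single split axis. They also return ``-1`` for a missing neighbor, because the documentation only promises some negative value.

```python
def linear_index(multi_index, grid_shape):
    """Row-major linearization: the last grid axis varies fastest."""
    idx = 0
    for coord, size in zip(multi_index, grid_shape):
        idx = idx * size + coord
    return idx


def multi_index(linear, grid_shape):
    """Inverse of linear_index."""
    coords = []
    for size in reversed(grid_shape):
        linear, coord = divmod(linear, size)
        coords.append(coord)
    return tuple(reversed(coords))


def neighbors_linear_indices(device, grid_shape, split_axis):
    """Linear indices of the previous and next device along split_axis.

    Returns -1 for a missing neighbor (hypothetical choice; the op only
    promises a negative value).
    """
    down, up = list(device), list(device)
    down[split_axis] -= 1
    up[split_axis] += 1
    down_idx = linear_index(down, grid_shape) if down[split_axis] >= 0 else -1
    up_idx = (linear_index(up, grid_shape)
              if up[split_axis] < grid_shape[split_axis] else -1)
    return down_idx, up_idx


grid_shape = (10, 20, 30)
print(linear_index((1, 2, 3), grid_shape))                 # 663
print(neighbors_linear_indices((1, 2, 3), grid_shape, 1))  # (633, 693)
print(neighbors_linear_indices((1, 0, 3), grid_shape, 1))  # (-1, 633)
```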
.. py:function:: neighbors_linear_indices(grid, device, split_axes, *, results=None, loc=None, ip=None) -> _ods_ir

.. py:class:: ProcessLinearIndexOp(grid, *, results=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Example:

   .. code::

      %idx = shard.process_linear_index on @grid : index

   If ``@grid`` has shape ``(10, 20, 30)``, a device with multi-index
   ``(1, 2, 3)`` has the linear index ``3 + 30*2 + 20*30*1 = 663``.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.process_linear_index'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: grid() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: process_linear_index(grid, *, results=None, loc=None, ip=None) -> _ods_ir

.. py:class:: ProcessMultiIndexOp(result, grid, *, axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   It is used in the SPMD format of IR. The ``axes`` must be non-negative and
   less than the total number of grid axes. If ``axes`` is empty, the index
   along all grid axes is returned.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.process_multi_index'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: grid() -> _ods_ir

   .. py:method:: axes() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: process_multi_index(result, grid, *, axes=None, loc=None, ip=None) -> Union[_ods_ir, _ods_ir, ProcessMultiIndexOp]

.. py:class:: RecvOp(result, grid, input, source_dynamic, *, grid_axes=None, source=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Receives a tensor from device ``source``, which is defined by its in-group
   multi-index. The groups are defined by ``grid_axes``. The content of the
   input tensor is ignored.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.recv'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: source_dynamic() -> _ods_ir

   .. py:method:: grid() -> _ods_ir
   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: source() -> Optional[_ods_ir]

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: recv(result, grid, input, source_dynamic, *, grid_axes=None, source=None, loc=None, ip=None) -> _ods_ir

.. py:class:: ReduceOp(result, grid, input, root, root_dynamic, *, grid_axes=None, reduction=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Reduces the input tensor across all devices within the groups defined by
   ``grid_axes``, using the specified reduction method. The operation performs
   an element-wise reduction over the tensor slices from all devices in each
   group. The reduction result is returned on the ``root`` device of each
   group and is undefined on all other (non-root) devices. The ``root``
   device is defined by its in-group multi-index. The accumulation element
   type is determined by the result type and does not need to match the input
   element type. Before the reduction is performed, each input element is
   converted to the result element type.

   Attributes:

   ``reduction``: Indicates the reduction method.

   Example:

   .. code::

      %1 = shard.reduce %0 on @grid0 grid_axes = [1, 0] reduction = <max> root = [2, 3] : (tensor<3x4xf32>) -> tensor<3x4xf64>

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.reduce'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: root_dynamic() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: reduction() -> _ods_ir

   .. py:method:: root() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: reduce(result, grid, input, root, root_dynamic, *, grid_axes=None, reduction=None, loc=None, ip=None) -> _ods_ir
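The ``shard.all_to_all`` example earlier in this module uses 3 devices with ``split_axis = 0`` and ``concat_axis = 0``. It can be reproduced with the pure-Python sketch below. The sketch models each device's tensor as a list of rows. It only handles a 1-D grid with split and concat along tensor axis 0. ``split_rows`` and ``all_to_all_sim`` are hypothetical helpers, not part of the bindings.

```python
def split_rows(tensor, parts):
    """Split a list-of-rows tensor into `parts` equal chunks along axis 0."""
    if len(tensor) % parts:
        raise ValueError("tensor axis 0 must be divisible by the group size")
    step = len(tensor) // parts
    return [tensor[i * step:(i + 1) * step] for i in range(parts)]


def all_to_all_sim(inputs):
    """inputs[j] is device j's tensor; returns the result on every device."""
    n = len(inputs)
    # pieces[src][dst] is the chunk that device `src` sends to device `dst`.
    pieces = [split_rows(t, n) for t in inputs]
    # Each device concatenates the chunks it receives, ordered by sender.
    return [[row for src in range(n) for row in pieces[src][dst]]
            for dst in range(n)]


inputs = [
    [[11, 12], [13, 14], [15, 16]],  # device 0
    [[21, 22], [23, 24], [25, 26]],  # device 1
    [[31, 32], [33, 34], [35, 36]],  # device 2
]
results = all_to_all_sim(inputs)
print(results[0])  # [[11, 12], [21, 22], [31, 32]]
```

Each output column of the result diagram corresponds to one entry of ``results``.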
.. py:class:: ReduceScatterOp(result, grid, input, scatter_axis, *, grid_axes=None, reduction=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Reduces the input tensor across all devices within the groups defined by
   ``grid_axes`` using the specified reduction method. The reduction is
   performed element-wise across the tensor pieces from all devices in the
   group. After the reduction, the result is scattered (split and distributed)
   across the device group along ``scatter_axis``.

   Example:

   .. code::

      shard.grid @grid0(shape = 2x2)
      ...
      %1 = shard.reduce_scatter %0 on @grid0 grid_axes = [1] reduction = <max> scatter_axis = 0 : tensor<2x2xf32> -> tensor<1x2xf64>

   Input:

   .. code::

                               device
                               (0, 1)
                                  ↓
                       +-------+-------+  | scatter tensor
      device (0, 0) -> |  1  2 |  5  6 |  | axis 0
                       |  3  4 |  7  8 |  ↓
                       +-------+-------+
      device (1, 0) -> |  9 10 | 13 14 |
                       | 11 12 | 15 16 |
                       +-------+-------+
                                  ↑
                               device
                               (1, 1)

   Result:

   .. code::

      +-------+
      |  5  6 | <- device (0, 0)
      +-------+
      |  7  8 | <- device (0, 1)
      +-------+
      | 13 14 | <- device (1, 0)
      +-------+
      | 15 16 | <- device (1, 1)
      +-------+

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.reduce_scatter'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: reduction() -> _ods_ir

   .. py:method:: scatter_axis() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: reduce_scatter(result, grid, input, scatter_axis, *, grid_axes=None, reduction=None, loc=None, ip=None) -> _ods_ir

.. py:class:: ScatterOp(result, grid, input, scatter_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   For each device group defined by ``grid_axes``, the input tensor on the
   ``root`` device is split along axis ``scatter_axis`` and distributed across
   the group. The content of the input on all other (non-root) devices is
   ignored.
   The ``root`` device is defined by its in-group multi-index.

   Example:

   .. code::

      shard.grid @grid0(shape = 2x2)
      %1 = shard.scatter %0 on @grid0 grid_axes = [0] scatter_axis = 0 root = [1] : (tensor<2x2xi8>) -> tensor<1x2xi8>

   Input:

   .. code::

                               device
                               (0, 1)
                                  ↓
                       +-------+-------+  | scatter tensor
      device (0, 0) -> |  *  * |  *  * |  | axis 0
                       |  *  * |  *  * |  ↓
                       +-------+-------+
      device (1, 0) -> |  1  2 |  5  6 |
                       |  3  4 |  7  8 |
                       +-------+-------+
                                  ↑
                               device
                               (1, 1)

   Result:

   .. code::

                               device
                               (0, 1)
                                  ↓
                       +-------+-------+
      device (0, 0) -> |  1  2 |  5  6 |
                       +-------+-------+
      device (1, 0) -> |  3  4 |  7  8 |
                       +-------+-------+
                                  ↑
                               device
                               (1, 1)

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.scatter'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: root_dynamic() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: scatter_axis() -> _ods_ir

   .. py:method:: root() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: scatter(result, grid, input, scatter_axis, root, root_dynamic, *, grid_axes=None, loc=None, ip=None) -> _ods_ir

.. py:class:: SendOp(result, grid, input, destination, destination_dynamic, *, grid_axes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Sends the input tensor to device ``destination``, which is defined by its
   in-group multi-index. The groups are defined by ``grid_axes``.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.send'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: destination_dynamic() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: destination() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: send(result, grid, input, destination, destination_dynamic, *, grid_axes=None, loc=None, ip=None) -> _ods_ir
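The ``shard.reduce_scatter`` example above (2x2 grid, ``grid_axes = [1]``, ``scatter_axis = 0``) can be checked with the minimal pure-Python sketch below. It assumes ``max`` as the reduction, which is what the example's result values imply. It models the ``f32 -> f64`` conversion with Python ``float``. Only a single grid axis and scattering along tensor axis 0 are handled. ``reduce_scatter_sim`` is a hypothetical helper, not part of the bindings.

```python
from itertools import product


def reduce_scatter_sim(data, grid_shape, grid_axis, op=max):
    """reduce_scatter over one grid axis, scattering along tensor axis 0.

    data maps device multi-index tuples to 2-D list-of-rows tensors.
    """
    size = grid_shape[grid_axis]
    out = {}
    for dev in product(*(range(n) for n in grid_shape)):
        # The group members differ from `dev` only along grid_axis.
        group = [dev[:grid_axis] + (k,) + dev[grid_axis + 1:]
                 for k in range(size)]
        tensors = [data[d] for d in group]
        rows, cols = len(tensors[0]), len(tensors[0][0])
        # Element-wise reduction, converting to float (f64) first.
        reduced = [[op(float(t[r][c]) for t in tensors) for c in range(cols)]
                   for r in range(rows)]
        if rows % size:
            raise ValueError("tensor axis 0 must be divisible by the group size")
        chunk = rows // size
        k = dev[grid_axis]  # in-group index selects this device's slice
        out[dev] = reduced[k * chunk:(k + 1) * chunk]
    return out


data = {
    (0, 0): [[1, 2], [3, 4]],
    (0, 1): [[5, 6], [7, 8]],
    (1, 0): [[9, 10], [11, 12]],
    (1, 1): [[13, 14], [15, 16]],
}
scattered = reduce_scatter_sim(data, (2, 2), grid_axis=1)
print(scattered[(0, 1)])  # [[7.0, 8.0]]
```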
.. py:class:: ShardOp(src, sharding, *, annotate_for_users=None, results=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   The ``shard.shard`` operation is designed to specify and guide the sharding
   behavior of a tensor value across a grid topology. This operation has two
   operands and one optional attribute:

   #. ``src``: This operand represents the tensor value that needs to be
      annotated for sharding.
   #. ``sharding``: This operand is of type ``ShardingType``, which is the core
      data structure used to represent the distribution of a tensor on a grid.
      It is typically defined by a ``shard.sharding`` operation.
   #. ``annotate_for_users``: A unit attribute addressing the scenario where a
      tensor's sharding annotation differs based on its context of use (either
      as a result or an operand). If specified, the sharding pertains to
      specific users of the tensor value, indicating how it should be
      considered when used as an operand in subsequent operations. If not, the
      sharding applies to the operation that defines the tensor value.

   Example:

   .. code::

      func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () {
        %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
        ...
      }

      func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () {
        %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
        ...
      }

      func.func @two_operands_annotated(%arg0 : tensor<4x8xf32>, %arg1 : tensor<16x8xf32>) -> () {
        %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding annotate_for_users : tensor<4x8xf32>
        %1 = shard.shard %arg1 to %sharding annotate_for_users : tensor<16x8xf32>
        ...
      }

      // The first shard.shard op applies to %arg0, the second shard.shard op
      // applies to the operand of op0, the third shard.shard op applies to the
      // operand of op1
      func.func @both_result_and_multi_operands_annotated(
          %arg0 : tensor<4x8xf32>) -> () {
        %sharding = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding : tensor<4x8xf32>
        %sharding1 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
        %1 = shard.shard %0 to %sharding1 annotate_for_users : tensor<4x8xf32>
        %sharding2 = shard.sharding @grid0 split_axes = [[2]] : !shard.sharding
        %2 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
        "op0"(%1) : ...
        "op1"(%2) : ...
        ...
      }

   The following usages are undefined:

   .. code::

      func.func @annotate_on_same_result_with_different_sharding(
          %arg0 : tensor<4x8xf32>) -> () {
        %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
        %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
        ...
      }

      func.func @annotate_on_same_result_same_value_with_different_sharding(
          %arg0 : tensor<4x8xf32>) -> () {
        %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>
        %1 = shard.shard %arg0 to %sharding2 : tensor<4x8xf32>
        ...
      }

      func.func @annotate_on_same_operand_with_different_sharding(
          %arg0 : tensor<4x8xf32>) -> () {
        %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
        %1 = shard.shard %0 to %sharding2 annotate_for_users : tensor<4x8xf32>
        ...
      }

      func.func @result_annotated_after_operand(
          %arg0 : tensor<4x8xf32>) -> () {
        %sharding1 = shard.sharding @grid0 split_axes = [[0]] : !shard.sharding
        %sharding2 = shard.sharding @grid0 split_axes = [[1]] : !shard.sharding
        %0 = shard.shard %arg0 to %sharding1 annotate_for_users : tensor<4x8xf32>
        %1 = shard.shard %0 to %sharding2 : tensor<4x8xf32>
        ...
      }

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.shard'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: src() -> _ods_ir

   .. py:method:: sharding() -> _ods_ir

   .. py:method:: annotate_for_users() -> bool

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: shard(src, sharding, *, annotate_for_users=None, results=None, loc=None, ip=None) -> _ods_ir

.. py:class:: ShardShapeOp(result, dims, dims_dynamic, sharding, device, device_dynamic, *, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   The device/process id is a multi-index of the device/process in the grid.
   This operation might be used during partitioning when the shard shape
   depends on (non-constant) values used in ``shard.sharding``.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.shard_shape'

   .. py:attribute:: _ODS_OPERAND_SEGMENTS

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: dims_dynamic() -> _ods_ir

   .. py:method:: sharding() -> _ods_ir

   .. py:method:: device_dynamic() -> _ods_ir

   .. py:method:: dims() -> _ods_ir

   .. py:method:: device() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: shard_shape(result, dims, dims_dynamic, sharding, device, device_dynamic, *, loc=None, ip=None) -> Union[_ods_ir, _ods_ir, ShardShapeOp]
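The local shape that ``shard.shard_shape`` computes can be sketched in plain Python for two cases.

- **Even split.** Each split tensor dimension is divided by the product of the grid sizes along its ``split_axes``. This rule is an assumption, and the sketch only covers evenly divisible dimensions.
- **Explicit offsets.** With ``sharded_dims_offsets``, a shard's size is the difference between consecutive offsets.

Both helpers are hypothetical and are not binding APIs. Dynamic (``?``) values and halos are not modeled. For the offsets case, the caller supplies the number of shards per sharded dimension and the device's shard index in each of those dimensions.

```python
from math import prod


def even_shard_shape(tensor_shape, split_axes, grid_shape):
    """Local shape when each split tensor dimension divides evenly.

    split_axes[i] lists the grid axes that tensor dimension i is split along;
    an empty list (or a missing trailing entry) means the dimension is replicated.
    """
    shape = list(tensor_shape)
    for dim, axes in enumerate(split_axes):
        if not axes:
            continue
        parts = prod(grid_shape[a] for a in axes)
        if shape[dim] % parts:
            raise ValueError(f"dimension {dim} is not evenly divisible")
        shape[dim] //= parts
    return tuple(shape)


def offsets_shard_shape(tensor_shape, sharded_dims, num_shards, offsets, shard_ids):
    """Local shape from a flattened sharded_dims_offsets array.

    For each sharded dimension, `offsets` holds num_shards + 1 values: the
    start of every shard followed by the end of the last one.
    """
    shape = list(tensor_shape)
    pos = 0
    for dim, n, k in zip(sharded_dims, num_shards, shard_ids):
        dim_offsets = offsets[pos:pos + n + 1]
        shape[dim] = dim_offsets[k + 1] - dim_offsets[k]
        pos += n + 1
    return tuple(shape)


# tensor<4x8xf32> with split_axes = [[0]] on a 2x2x4 grid.
print(even_shard_shape((4, 8), [[0]], (2, 2, 4)))  # (2, 8)

# 32x32x32 tensor, first two dims sharded into 2 shards each.
offs = [0, 24, 32, 0, 20, 32]
print(offsets_shard_shape((32, 32, 32), [0, 1], [2, 2], offs, (0, 0)))  # (24, 20, 32)
print(offsets_shard_shape((32, 32, 32), [0, 1], [2, 2], offs, (1, 1)))  # (8, 12, 32)
```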
.. py:class:: ShardingOp(grid, split_axes, dynamic_sharded_dims_offsets, dynamic_halo_sizes, *, static_sharded_dims_offsets=None, static_halo_sizes=None, results=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   The Sharding specifies how a tensor is sharded and distributed across the
   process grid. It is typically used in a ``shard.shard`` operation. The
   operation has the following attributes and operands:

   #. ``grid``: This attribute is a FlatSymbolRefAttr that refers to the
      device grid where the distributed tensor is placed. The symbol must
      resolve to a ``shard.grid`` operation.
   #. ``split_axes``: An array composed of int64_t sub-arrays. The outer
      array's maximum size is the ``rank`` of the related tensor. For the i-th
      sub-array, if its value is [x, y], the tensor's i-th dimension is split
      along the x and y axes of the device grid.
   #. [Optional] Sizes of halos to be added for each sharded tensor dimension.
      ``halo_sizes`` is provided as a flattened 1d array of i64s, 2 values for
      each sharded dimension. ``halo_sizes = [1, 2]`` means that the first
      sharded dimension gets an additional halo of size 1 at its start and a
      halo of size 2 at its end. ``halo_sizes = [1, 2, 2, 3]`` defines halos
      for the first 2 sharded dimensions, i.e. the first sharded dimension
      gets ``[1, 2]`` halos and the second gets ``[2, 3]`` halos. ``?``
      indicates dynamic halo sizes.
   #. [Optional] Offsets for each shard and sharded tensor dimension.
      ``sharded_dims_offsets`` is provided as a flattened 1d array of i64s.
      For each sharded tensor dimension, the offsets (starting indices) of all
      shards in that dimension are provided, followed by an additional value
      for the end of the last shard. For a 1d sharding this means that
      position ``i`` holds the exclusive prefix sum for shard ``i``, and since
      only contiguous sharding is supported, its inclusive prefix sum is at
      position ``i+1``.
      Assuming a 3d-tensor of shape 32x32x32 with the first 2 dimensions
      being sharded, ``sharded_dims_offsets = [0, 24, 32, 0, 20, 32]`` means
      that the first device of the device grid gets a shard of shape 24x20x32
      and the second device gets a shard of shape 8x12x32. ``?`` indicates
      dynamic shard dimensions.

   ``halo_sizes`` and ``sharded_dims_offsets`` are mutually exclusive.

   Examples:

   .. code::

      shard.grid @grid0(shape = 2x2x4)
      shard.grid @grid1d_4(shape = 4)

      // The tensor is fully replicated on @grid0.
      // Currently, there must be at least one sub-array present in axes, even
      // if it's empty. Otherwise, a parsing error will occur.
      %sharding0 = shard.sharding @grid0 split_axes = [[]]

      // The tensor is sharded on the first dimension along axis 0 of @grid0
      %sharding1 = shard.sharding @grid0 split_axes = [[0]]

      // Could be used for a shard.shard op
      %sharded0 = shard.shard %arg0 to %sharding1 : tensor<4x8xf32>

      // The tensor is sharded on its first dimension along axis 0 of @grid0 and
      // it has halo sizes of 1 and 2 on the sharded dim.
      %halo_sharding = shard.sharding @grid0 split_axes = [[0]] halo_sizes = [1, 2]
      %sharded1 = shard.shard %arg0 to %halo_sharding : tensor<4x8xf32>

      // The tensor is sharded on its second dimension along axis 0 of @grid1d_4
      // and it has pre-defined shard sizes. The shards of the devices will have
      // the following shapes: [4x2, 4x3, 4x4, 4x5]
      %sharding4 = shard.sharding @grid1d_4 split_axes = [[], [0]] sharded_dims_offsets = [0, 2, 5, 9, 14]
      %sharded2 = shard.shard %arg0 to %sharding4 : tensor<4x14xf32>

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.sharding'

   .. py:attribute:: _ODS_OPERAND_SEGMENTS

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: dynamic_sharded_dims_offsets() -> _ods_ir

   .. py:method:: dynamic_halo_sizes() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: split_axes() -> _ods_ir

   .. py:method:: static_sharded_dims_offsets() -> _ods_ir

   .. py:method:: static_halo_sizes() -> _ods_ir
   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: sharding(grid, split_axes, dynamic_sharded_dims_offsets, dynamic_halo_sizes, *, static_sharded_dims_offsets=None, static_halo_sizes=None, results=None, loc=None, ip=None) -> _ods_ir

.. py:class:: ShiftOp(result, grid, input, shift_axis, offset, *, grid_axes=None, rotate=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   Within each device group defined by ``grid_axes``, shifts input tensors
   along the device grid's axis ``shift_axis`` by the specified offset. The
   ``shift_axis`` must be one of the ``grid_axes``. If the ``rotate``
   attribute is set, the shift is circular, that is, the offset wraps around
   according to the group size along ``shift_axis``. Otherwise, the results on
   devices without a corresponding source are undefined.

   Example:

   .. code::

      shard.grid @grid0(shape = 2x4)
      %1 = shard.shift %0 on @grid0 grid_axes = [1] shift_axis = 1 offset = 2 rotate : tensor<2xi8> -> tensor<2xi8>

   Input:

   .. code::

      grid axis 1
      ----------->

      +----+----+----+----+
      |  1 |  2 |  3 |  4 |
      +----+----+----+----+
      |  5 |  6 |  7 |  8 |
      +----+----+----+----+

   Result:

   .. code::

      +----+----+----+----+
      |  3 |  4 |  1 |  2 |
      +----+----+----+----+
      |  7 |  8 |  5 |  6 |
      +----+----+----+----+

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.shift'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: input() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: grid_axes() -> _ods_ir

   .. py:method:: shift_axis() -> _ods_ir

   .. py:method:: offset() -> _ods_ir

   .. py:method:: rotate() -> bool

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: shift(result, grid, input, shift_axis, offset, *, grid_axes=None, rotate=None, loc=None, ip=None) -> _ods_ir
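The ``shard.shift`` semantics can be sketched along a single grid axis in plain Python. This is a hypothetical illustration, and ``shift_sim`` is not a binding API. It assumes that device ``j`` receives from device ``j - offset``. The example above uses offset 2 on a group of size 4, which produces the same result in either direction, so it cannot confirm the direction. Undefined results (no source device, no ``rotate``) are modeled as ``None``.

```python
def shift_sim(values, offset, rotate):
    """Shift per-device values along one grid axis.

    values[j] is the value on device j along shift_axis. Assumes device j
    receives from device j - offset; undefined results are modeled as None.
    """
    n = len(values)
    out = []
    for j in range(n):
        src = j - offset
        if rotate:
            out.append(values[src % n])  # wrap around the group
        elif 0 <= src < n:
            out.append(values[src])
        else:
            out.append(None)  # no source device: undefined result
    return out


# Each row of the 2x4 @grid0 example is one group along grid axis 1.
grid = [[1, 2, 3, 4], [5, 6, 7, 8]]
print([shift_sim(row, offset=2, rotate=True) for row in grid])
# [[3, 4, 1, 2], [7, 8, 5, 6]]
```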
.. py:class:: UpdateHaloOp(result, destination, grid, split_axes, halo_sizes, *, static_halo_sizes=None, loc=None, ip=None)

   Bases: :py:obj:`_ods_ir`

   This operation updates the halo regions of shards, e.g. when their sharding
   specifies halos and the actual tensor/memref data might have changed on the
   remote devices. Changes might be caused by mutating operations and/or by
   the new halo regions being larger than the existing ones.

   The destination is expected to be initialized with the local data (not the
   halos).

   Assumes all devices hold tensors with same-sized halo data as specified by
   ``halo_sizes``/``static_halo_sizes``.

   ``split_axes`` specifies, for each tensor axis, along which grid axes its
   halo data is updated.

   .. py:attribute:: OPERATION_NAME
      :value: 'shard.update_halo'

   .. py:attribute:: _ODS_REGIONS
      :value: (0, True)

   .. py:method:: destination() -> _ods_ir

   .. py:method:: halo_sizes() -> _ods_ir

   .. py:method:: grid() -> _ods_ir

   .. py:method:: split_axes() -> _ods_ir

   .. py:method:: static_halo_sizes() -> _ods_ir

   .. py:method:: result() -> _ods_ir

      Shortcut to get an op result if it has only one (throws an error otherwise).

.. py:function:: update_halo(result, destination, grid, split_axes, halo_sizes, *, static_halo_sizes=None, loc=None, ip=None) -> _ods_ir
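The sketch below makes the halo semantics of ``shard.sharding`` and ``shard.update_halo`` concrete. It fills halos for a tensor sharded along a single dimension. Here ``halo_sizes = [lo, hi]`` means ``lo`` elements are taken from the end of the previous shard and ``hi`` elements from the start of the next shard. It is a hypothetical illustration only. It ignores ``split_axes`` over multiple grid axes, dynamic sizes, and memref semantics. It assumes halos at the ends of the grid axis are left unset (``None``) and that each shard is at least as long as the halos it supplies. ``update_halo_1d`` is not a binding API.

```python
def update_halo_1d(shards, halo_lo, halo_hi):
    """Rebuild halos for shards of a tensor split along one dimension.

    shards[k] is shard k's local data (no halos). Each result is laid out
    as [lower halo | local data | upper halo]; the lower halo comes from the
    end of shard k - 1 and the upper halo from the start of shard k + 1.
    Halos without a neighbor are left unset (None).
    """
    n = len(shards)
    result = []
    for k, local in enumerate(shards):
        lower = [None] * halo_lo
        upper = [None] * halo_hi
        if k > 0 and halo_lo:
            lower = list(shards[k - 1][-halo_lo:])
        if k < n - 1 and halo_hi:
            upper = list(shards[k + 1][:halo_hi])
        result.append(lower + list(local) + upper)
    return result


# halo_sizes = [1, 2] on a 9-element tensor split into 3 shards.
shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
for buf in update_halo_1d(shards, 1, 2):
    print(buf)
# [None, 0, 1, 2, 3, 4]
# [2, 3, 4, 5, 6, 7]
# [5, 6, 7, 8, None, None]
```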