MLIR  20.0.0git
Public Member Functions | Static Public Member Functions | List of all members
mlir::mesh::MeshSharding Class Reference

#include "mlir/Dialect/Mesh/IR/MeshOps.h"

Public Member Functions

 MeshSharding ()=default
 
 MeshSharding (Value rhs)
 
::mlir::FlatSymbolRefAttr getMeshAttr () const
 
::llvm::StringRef getMesh () const
 
ArrayRef< MeshAxesAttr > getSplitAxes () const
 
ArrayRef< MeshAxis > getPartialAxes () const
 
ReductionKind getPartialType () const
 
ArrayRef< int64_t > getStaticHaloSizes () const
 
ArrayRef< int64_t > getStaticShardedDimsSizes () const
 
ArrayRef< Value > getDynamicHaloSizes () const
 
ArrayRef< Value > getDynamicShardedDimsSizes () const
 
 operator bool () const
 
bool operator== (Value rhs) const
 
bool operator!= (Value rhs) const
 
bool operator== (const MeshSharding &rhs) const
 
bool operator!= (const MeshSharding &rhs) const
 
bool equalSplitAndPartialAxes (const MeshSharding &rhs) const
 
bool equalHaloAndShardSizes (const MeshSharding &rhs) const
 

Static Public Member Functions

static MeshSharding get (::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={})
 

Detailed Description

Definition at line 41 of file MeshOps.h.

Constructor & Destructor Documentation

◆ MeshSharding() [1/2]

mlir::mesh::MeshSharding::MeshSharding ( )
default

◆ MeshSharding() [2/2]

MeshSharding::MeshSharding ( Value  rhs)

Definition at line 625 of file MeshOps.cpp.

References get(), and mlir::Value::getDefiningOp().

Member Function Documentation

◆ equalHaloAndShardSizes()

bool MeshSharding::equalHaloAndShardSizes ( const MeshSharding & rhs) const

◆ equalSplitAndPartialAxes()

bool MeshSharding::equalSplitAndPartialAxes ( const MeshSharding & rhs) const

Definition at line 553 of file MeshOps.cpp.

References getMesh(), getPartialAxes(), getPartialType(), getSplitAxes(), and min().

Referenced by operator==().

◆ get()

MeshSharding MeshSharding::get ( ::mlir::FlatSymbolRefAttr  mesh_,
ArrayRef< MeshAxesAttr >  split_axes_,
ArrayRef< MeshAxis >  partial_axes_ = {},
ReductionKind  partial_type_ = ReductionKind::Sum,
ArrayRef< int64_t >  static_halo_sizes_ = {},
ArrayRef< int64_t >  static_sharded_dims_sizes_ = {},
ArrayRef< Value >  dynamic_halo_sizes_ = {},
ArrayRef< Value >  dynamic_sharded_dims_sizes_ = {} 
)
static

◆ getDynamicHaloSizes()

ArrayRef<Value> mlir::mesh::MeshSharding::getDynamicHaloSizes ( ) const
inline

Definition at line 72 of file MeshOps.h.

Referenced by equalHaloAndShardSizes().

◆ getDynamicShardedDimsSizes()

ArrayRef<Value> mlir::mesh::MeshSharding::getDynamicShardedDimsSizes ( ) const
inline

Definition at line 73 of file MeshOps.h.

Referenced by equalHaloAndShardSizes().

◆ getMesh()

::llvm::StringRef mlir::mesh::MeshSharding::getMesh ( ) const
inline

◆ getMeshAttr()

::mlir::FlatSymbolRefAttr mlir::mesh::MeshSharding::getMeshAttr ( ) const
inline

◆ getPartialAxes()

ArrayRef<MeshAxis> mlir::mesh::MeshSharding::getPartialAxes ( ) const
inline

◆ getPartialType()

ReductionKind mlir::mesh::MeshSharding::getPartialType ( ) const
inline

◆ getSplitAxes()

ArrayRef<MeshAxesAttr> mlir::mesh::MeshSharding::getSplitAxes ( ) const
inline

◆ getStaticHaloSizes()

ArrayRef<int64_t> mlir::mesh::MeshSharding::getStaticHaloSizes ( ) const
inline

◆ getStaticShardedDimsSizes()

ArrayRef<int64_t> mlir::mesh::MeshSharding::getStaticShardedDimsSizes ( ) const
inline

◆ operator bool()

mlir::mesh::MeshSharding::operator bool ( ) const
inline

Definition at line 76 of file MeshOps.h.

◆ operator!=() [1/2]

bool MeshSharding::operator!= ( const MeshSharding & rhs) const

Definition at line 621 of file MeshOps.cpp.

◆ operator!=() [2/2]

bool MeshSharding::operator!= ( Value  rhs) const

Definition at line 615 of file MeshOps.cpp.

◆ operator==() [1/2]

bool MeshSharding::operator== ( const MeshSharding & rhs) const

Definition at line 617 of file MeshOps.cpp.

References equalHaloAndShardSizes(), and equalSplitAndPartialAxes().

◆ operator==() [2/2]

bool MeshSharding::operator== ( Value  rhs) const

Definition at line 611 of file MeshOps.cpp.

References equalHaloAndShardSizes(), and equalSplitAndPartialAxes().


The documentation for this class was generated from the following files: