MLIR 20.0.0git
MeshSharding Class Reference
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
Public Member Functions
MeshSharding ()=default
MeshSharding (Value rhs)
::mlir::FlatSymbolRefAttr getMeshAttr () const
::llvm::StringRef getMesh () const
ArrayRef< MeshAxesAttr > getSplitAxes () const
ArrayRef< MeshAxis > getPartialAxes () const
ReductionKind getPartialType () const
ArrayRef< int64_t > getStaticHaloSizes () const
ArrayRef< int64_t > getStaticShardedDimsOffsets () const
ArrayRef< Value > getDynamicHaloSizes () const
ArrayRef< Value > getDynamicShardedDimsOffsets () const
operator bool () const
bool operator== (Value rhs) const
bool operator!= (Value rhs) const
bool operator== (const MeshSharding &rhs) const
bool operator!= (const MeshSharding &rhs) const
bool equalSplitAndPartialAxes (const MeshSharding &rhs) const
bool equalHaloAndShardSizes (const MeshSharding &rhs) const
bool equalHaloSizes (const MeshSharding &rhs) const
bool equalShardSizes (const MeshSharding &rhs) const
Static Public Member Functions
static MeshSharding get (::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
MeshSharding::MeshSharding ( ) [default]
MeshSharding::MeshSharding ( Value rhs )
Definition at line 710 of file MeshOps.cpp.
References get(), and mlir::Value::getDefiningOp().
bool MeshSharding::equalHaloAndShardSizes ( const MeshSharding & rhs ) const
Definition at line 653 of file MeshOps.cpp.
References equalHaloSizes(), and equalShardSizes().
Referenced by operator==().
bool MeshSharding::equalHaloSizes ( const MeshSharding & rhs ) const
Definition at line 678 of file MeshOps.cpp.
References getDynamicHaloSizes(), and getStaticHaloSizes().
Referenced by equalHaloAndShardSizes(), and mlir::mesh::tryUpdateHaloInResharding().
bool MeshSharding::equalShardSizes ( const MeshSharding & rhs ) const
Definition at line 657 of file MeshOps.cpp.
References getDynamicShardedDimsOffsets(), and getStaticShardedDimsOffsets().
Referenced by equalHaloAndShardSizes().
bool MeshSharding::equalSplitAndPartialAxes ( const MeshSharding & rhs ) const
Definition at line 623 of file MeshOps.cpp.
References getMesh(), getPartialAxes(), getPartialType(), getSplitAxes(), and min().
Referenced by operator==(), and mlir::mesh::tryUpdateHaloInResharding().
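MeshSharding MeshSharding::get ( ::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_ = {}, ReductionKind partial_type_ = ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_ = {}, ArrayRef< int64_t > static_sharded_dims_offsets_ = {}, ArrayRef< Value > dynamic_halo_sizes_ = {}, ArrayRef< Value > dynamic_sharded_dims_offsets_ = {} ) [static]
Definition at line 722 of file MeshOps.cpp.
References mlir::clone(), copy(), mlir::detail::enumerate(), and mlir::detail::DenseArrayAttrImpl< T >::get().
Referenced by getSharding(), mlir::mesh::handlePartialAxesDuringResharding(), MeshSharding(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), and mlir::mesh::targetShardingInUnsplitLastAxis().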
ArrayRef< Value > MeshSharding::getDynamicHaloSizes ( ) const
Definition at line 73 of file MeshOps.h.
Referenced by equalHaloSizes(), and mlir::mesh::tryUpdateHaloInResharding().
ArrayRef< Value > MeshSharding::getDynamicShardedDimsOffsets ( ) const
Definition at line 74 of file MeshOps.h.
Referenced by equalShardSizes().
::llvm::StringRef MeshSharding::getMesh ( ) const [inline]
Definition at line 65 of file MeshOps.h.
References mlir::FlatSymbolRefAttr::getValue().
Referenced by mlir::linalg::createAllReduceForResultWithoutPartialSharding(), and equalSplitAndPartialAxes().
ArrayRef< MeshAxis > MeshSharding::getPartialAxes ( ) const [inline]
Definition at line 67 of file MeshOps.h.
Referenced by mlir::linalg::createAllReduceForResultWithoutPartialSharding(), mlir::mesh::detail::defaultGetShardingOption(), equalSplitAndPartialAxes(), mlir::mesh::handlePartialAxesDuringResharding(), mlir::mesh::isFullReplication(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), mlir::mesh::targetShardingInUnsplitLastAxis(), and mlir::mesh::tryUpdateHaloInResharding().
ReductionKind MeshSharding::getPartialType ( ) const [inline]
Definition at line 68 of file MeshOps.h.
Referenced by equalSplitAndPartialAxes(), mlir::mesh::handlePartialAxesDuringResharding(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), and mlir::mesh::targetShardingInUnsplitLastAxis().
ArrayRef< MeshAxesAttr > MeshSharding::getSplitAxes ( ) const [inline]
Definition at line 66 of file MeshOps.h.
Referenced by mlir::mesh::detail::defaultGetShardingOption(), mlir::mesh::detectMoveLastSplitAxisInResharding(), mlir::mesh::detectSplitLastAxisInResharding(), mlir::mesh::detectUnsplitLastAxisInResharding(), equalSplitAndPartialAxes(), mlir::mesh::handlePartialAxesDuringResharding(), mlir::mesh::isFullReplication(), mlir::mesh::shardShapedType(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), mlir::mesh::targetShardingInUnsplitLastAxis(), and mlir::mesh::tryUpdateHaloInResharding().
ArrayRef< int64_t > MeshSharding::getStaticHaloSizes ( ) const [inline]
Definition at line 69 of file MeshOps.h.
Referenced by equalHaloSizes(), mlir::mesh::reshardOn1DMesh(), mlir::mesh::shardShapedType(), and mlir::mesh::tryUpdateHaloInResharding().
ArrayRef< int64_t > MeshSharding::getStaticShardedDimsOffsets ( ) const [inline]
Definition at line 70 of file MeshOps.h.
Referenced by equalShardSizes(), mlir::mesh::reshardOn1DMesh(), mlir::mesh::shardShapedType(), and mlir::mesh::tryUpdateHaloInResharding().
bool MeshSharding::operator!= ( const MeshSharding & rhs ) const
Definition at line 706 of file MeshOps.cpp.
bool MeshSharding::operator!= ( Value rhs ) const
Definition at line 700 of file MeshOps.cpp.
bool MeshSharding::operator== ( const MeshSharding & rhs ) const
Definition at line 702 of file MeshOps.cpp.
References equalHaloAndShardSizes(), and equalSplitAndPartialAxes().
bool MeshSharding::operator== ( Value rhs ) const
Definition at line 696 of file MeshOps.cpp.
References equalHaloAndShardSizes(), and equalSplitAndPartialAxes().