MLIR
20.0.0git
|
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
Public Member Functions | |
MeshSharding ()=default | |
MeshSharding (Value rhs) | |
::mlir::FlatSymbolRefAttr | getMeshAttr () const |
::llvm::StringRef | getMesh () const |
ArrayRef< MeshAxesAttr > | getSplitAxes () const |
ArrayRef< MeshAxis > | getPartialAxes () const |
ReductionKind | getPartialType () const |
ArrayRef< int64_t > | getStaticHaloSizes () const |
ArrayRef< int64_t > | getStaticShardedDimsSizes () const |
ArrayRef< Value > | getDynamicHaloSizes () const |
ArrayRef< Value > | getDynamicShardedDimsSizes () const |
operator bool () const | |
bool | operator== (Value rhs) const |
bool | operator!= (Value rhs) const |
bool | operator== (const MeshSharding &rhs) const |
bool | operator!= (const MeshSharding &rhs) const |
bool | equalSplitAndPartialAxes (const MeshSharding &rhs) const |
bool | equalHaloAndShardSizes (const MeshSharding &rhs) const |
Static Public Member Functions | |
static MeshSharding | get (::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={}) |
MeshSharding::MeshSharding | ( | ) |
default |
MeshSharding::MeshSharding | ( | Value | rhs | ) |
Definition at line 625 of file MeshOps.cpp.
References get(), and mlir::Value::getDefiningOp().
bool MeshSharding::equalHaloAndShardSizes | ( | const MeshSharding & | rhs | ) | const |
Definition at line 578 of file MeshOps.cpp.
References getDynamicHaloSizes(), getDynamicShardedDimsSizes(), getStaticHaloSizes(), and getStaticShardedDimsSizes().
Referenced by operator==().
bool MeshSharding::equalSplitAndPartialAxes | ( | const MeshSharding & | rhs | ) | const |
Definition at line 553 of file MeshOps.cpp.
References getMesh(), getPartialAxes(), getPartialType(), getSplitAxes(), and min().
Referenced by operator==().
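MeshSharding MeshSharding::get | ( | ::mlir::FlatSymbolRefAttr | mesh_, | ArrayRef< MeshAxesAttr > | split_axes_, | ArrayRef< MeshAxis > | partial_axes_ = {}, | ReductionKind | partial_type_ = ReductionKind::Sum, | ArrayRef< int64_t > | static_halo_sizes_ = {}, | ArrayRef< int64_t > | static_sharded_dims_sizes_ = {}, | ArrayRef< Value > | dynamic_halo_sizes_ = {}, | ArrayRef< Value > | dynamic_sharded_dims_sizes_ = {} | ) |
static |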
Definition at line 637 of file MeshOps.cpp.
References mlir::clone(), copy(), mlir::detail::enumerate(), and mlir::detail::DenseArrayAttrImpl< T >::get().
Referenced by getSharding(), mlir::mesh::handlePartialAxesDuringResharding(), MeshSharding(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), and mlir::mesh::targetShardingInUnsplitLastAxis().
ArrayRef< Value > MeshSharding::getDynamicHaloSizes | ( | ) | const |
Definition at line 72 of file MeshOps.h.
Referenced by equalHaloAndShardSizes().
ArrayRef< Value > MeshSharding::getDynamicShardedDimsSizes | ( | ) | const |
Definition at line 73 of file MeshOps.h.
Referenced by equalHaloAndShardSizes().
::llvm::StringRef MeshSharding::getMesh | ( | ) | const |
inline |
Definition at line 64 of file MeshOps.h.
References mlir::FlatSymbolRefAttr::getValue().
Referenced by mlir::linalg::createAllReduceForResultWithoutPartialSharding(), and equalSplitAndPartialAxes().
::mlir::FlatSymbolRefAttr MeshSharding::getMeshAttr | ( | ) | const |
inline |
Definition at line 63 of file MeshOps.h.
Referenced by mlir::mesh::handlePartialAxesDuringResharding(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), and mlir::mesh::targetShardingInUnsplitLastAxis().
ArrayRef< MeshAxis > MeshSharding::getPartialAxes | ( | ) | const |
Definition at line 66 of file MeshOps.h.
Referenced by mlir::linalg::createAllReduceForResultWithoutPartialSharding(), equalSplitAndPartialAxes(), mlir::mesh::handlePartialAxesDuringResharding(), mlir::mesh::isFullReplication(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), and mlir::mesh::targetShardingInUnsplitLastAxis().
ReductionKind MeshSharding::getPartialType | ( | ) | const |
inline |
Definition at line 67 of file MeshOps.h.
Referenced by equalSplitAndPartialAxes(), mlir::mesh::handlePartialAxesDuringResharding(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), and mlir::mesh::targetShardingInUnsplitLastAxis().
ArrayRef< MeshAxesAttr > MeshSharding::getSplitAxes | ( | ) | const |
inline |
Definition at line 65 of file MeshOps.h.
Referenced by mlir::mesh::detectMoveLastSplitAxisInResharding(), mlir::mesh::detectSplitLastAxisInResharding(), mlir::mesh::detectUnsplitLastAxisInResharding(), equalSplitAndPartialAxes(), mlir::mesh::handlePartialAxesDuringResharding(), mlir::mesh::isFullReplication(), mlir::mesh::shardShapedType(), mlir::mesh::targetShardingInMoveLastAxis(), mlir::mesh::targetShardingInSplitLastAxis(), and mlir::mesh::targetShardingInUnsplitLastAxis().
ArrayRef< int64_t > MeshSharding::getStaticHaloSizes | ( | ) | const |
inline |
Definition at line 68 of file MeshOps.h.
Referenced by equalHaloAndShardSizes(), mlir::mesh::reshardOn1DMesh(), and mlir::mesh::shardShapedType().
ArrayRef< int64_t > MeshSharding::getStaticShardedDimsSizes | ( | ) | const |
inline |
Definition at line 69 of file MeshOps.h.
Referenced by equalHaloAndShardSizes(), mlir::mesh::reshardOn1DMesh(), and mlir::mesh::shardShapedType().
bool MeshSharding::operator!= | ( | const MeshSharding & | rhs | ) | const |
Definition at line 621 of file MeshOps.cpp.
bool MeshSharding::operator!= | ( | Value | rhs | ) | const |
Definition at line 615 of file MeshOps.cpp.
bool MeshSharding::operator== | ( | const MeshSharding & | rhs | ) | const |
Definition at line 617 of file MeshOps.cpp.
References equalHaloAndShardSizes(), and equalSplitAndPartialAxes().
bool MeshSharding::operator== | ( | Value | rhs | ) | const |
Definition at line 611 of file MeshOps.cpp.
References equalHaloAndShardSizes(), and equalSplitAndPartialAxes().