MLIR
20.0.0git
|
Namespaces | |
detail | |
Classes | |
struct | ShardingOption |
struct | IndependentParallelIteratorDomainShardingInterface |
struct | ElementwiseShardingInterface |
class | MeshSharding |
struct | OpRewritePatternWithSymbolTableCollection |
Typedefs | |
using | ShardingArray = SmallVector< SmallVector< MeshAxis > > |
using | ShardingArrayRef = ArrayRef< SmallVector< MeshAxis > > |
using | MeshAxis = int16_t |
using | MeshAxesAttr = DenseI16ArrayAttr |
using | ShardShapeAttr = DenseI64ArrayAttr |
using | HaloSizePairAttr = DenseI64ArrayAttr |
using | UnshardedToShardedValueMap = DenseMap< Value, Value > |
Functions | |
FailureOr< std::pair< bool, MeshSharding > > | getMeshSharding (OpResult result) |
FailureOr< std::pair< bool, MeshSharding > > | getMeshSharding (OpOperand &opOperand) |
void | spmdizeFullyReplicatedOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
ShardingArray | getMeshAxisAssignmentForLoopIterators (ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps) |
bool | isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators) |
SmallVector< MeshAxis > | getReductionMeshAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators) |
void | spmdizeTriviallyShardableOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
bool | isReductionLoop (utils::IteratorType iType) |
template<typename T > | |
void | removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array) |
bool | isFullReplication (MeshSharding sharding) |
mesh::MeshOp | getMeshOrNull (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection) |
mesh::MeshOp | getMesh (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection) |
template<typename Op > | |
mesh::MeshOp | getMesh (Op op, SymbolTableCollection &symbolTableCollection) |
template<> | |
mesh::MeshOp | getMesh< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection) |
template<typename MeshAxesRange , typename MeshShapeRange > | |
int64_t | collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape) |
template<typename MeshAxesRange > | |
int64_t | collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshOp mesh) |
int64_t | shardDimension (int64_t dimSize, int64_t shardCount) |
int64_t | gatherDimension (int64_t dimSize, int64_t shardCount) |
ShapedType | shardShapedType (ShapedType shape, MeshOp mesh, MeshSharding sharding) |
Type | shardType (Type type, MeshOp mesh, MeshSharding sharding) |
void | maybeInsertTargetShardingAnnotation (MeshSharding sharding, OpOperand &operand, OpBuilder &builder) |
void | maybeInsertTargetShardingAnnotation (MeshSharding sharding, OpResult result, OpBuilder &builder) |
void | maybeInsertSourceShardingAnnotation (MeshSharding sharding, OpOperand &operand, OpBuilder &builder) |
template<typename AlgebraicOp > | |
void | populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction) |
void | populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
TypedValue< ShapedType > | reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
TypedValue< ShapedType > | reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
void | reshardingRegisterDependentDialects (DialectRegistry &registry) |
void | populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry) |
void | populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerAllSliceOpLoweringDialects (DialectRegistry &registry) |
void | populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerAllOpLoweringDialects (DialectRegistry &registry) |
TypedValue< IndexType > | createCollectiveProcessGroupSize (MeshOp mesh, ArrayRef< MeshAxis > axes, ImplicitLocOpBuilder &builder) |
TypedValue< IndexType > | createProcessLinearIndex (StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder) |
template<typename SourceAxes , typename TargetAxes > | |
static bool | arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
static std::tuple< TypedValue< ShapedType >, MeshSharding > | handlePartialAxesDuringResharding (OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard) |
static MeshSharding | targetShardingInSplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
static std::tuple< TypedValue< ShapedType >, MeshSharding > | splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
static std::optional< std::tuple< int64_t, MeshAxis > > | detectSplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding) |
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > | trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< int64_t, MeshAxis > > | detectUnsplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding) |
static MeshSharding | targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis) |
static ShapedType | allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) |
static std::tuple< TypedValue< ShapedType >, MeshSharding > | unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > | tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > | detectMoveLastSplitAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding) |
static MeshSharding | targetShardingInMoveLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static ShapedType | allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static std::tuple< TypedValue< ShapedType >, MeshSharding > | moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis) |
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > | tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > | tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static TypedValue< ShapedType > | reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
TypedValue< ShapedType > | reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
SmallVector< Type > | shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
static LogicalResult | spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static std::vector< MeshSharding > | getOperandShardings (Operation &op) |
static std::vector< MeshSharding > | getResultShardings (Operation &op) |
static LogicalResult | spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection) |
using mlir::mesh::HaloSizePairAttr = typedef DenseI64ArrayAttr |
using mlir::mesh::MeshAxesAttr = typedef DenseI16ArrayAttr |
using mlir::mesh::MeshAxis = typedef int16_t |
using mlir::mesh::ShardingArray = typedef SmallVector<SmallVector<MeshAxis> > |
Definition at line 25 of file ShardingInterface.h.
using mlir::mesh::ShardingArrayRef = typedef ArrayRef<SmallVector<MeshAxis> > |
Definition at line 26 of file ShardingInterface.h.
using mlir::mesh::ShardShapeAttr = typedef DenseI64ArrayAttr |
using mlir::mesh::UnshardedToShardedValueMap = typedef DenseMap<Value, Value> |
Definition at line 613 of file Spmdization.cpp.
|
static |
Definition at line 249 of file Spmdization.cpp.
References gatherDimension().
Referenced by unsplitLastAxisInResharding().
|
static |
Definition at line 376 of file Spmdization.cpp.
References gatherDimension(), and shardDimension().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 44 of file Spmdization.cpp.
Referenced by handlePartialAxesDuringResharding().
int64_t mlir::mesh::collectiveProcessGroupSize | ( | MeshAxesRange && | meshAxes, |
MeshOp | mesh | ||
) |
Definition at line 167 of file MeshOps.h.
References collectiveProcessGroupSize().
int64_t mlir::mesh::collectiveProcessGroupSize | ( | MeshAxesRange && | meshAxes, |
MeshShapeRange && | meshShape | ||
) |
Definition at line 151 of file MeshOps.h.
Referenced by collectiveProcessGroupSize(), sliceResultType(), verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
TypedValue< IndexType > mlir::mesh::createCollectiveProcessGroupSize | ( | MeshOp | mesh, |
ArrayRef< MeshAxis > | axes, | ||
ImplicitLocOpBuilder & | builder | ||
) |
Definition at line 201 of file Transforms.cpp.
References mlir::ImplicitLocOpBuilder::create(), mlir::arith::createProduct(), mlir::Builder::getIndexType(), and mlir::ImplicitLocOpBuilder::getLoc().
TypedValue< IndexType > mlir::mesh::createProcessLinearIndex | ( | StringRef | mesh, |
ArrayRef< MeshAxis > | meshAxes, | ||
ImplicitLocOpBuilder & | builder | ||
) |
Definition at line 210 of file Transforms.cpp.
References mlir::ImplicitLocOpBuilder::create(), and mlir::affine::linearizeIndex().
Referenced by mlir::linalg::createDestinationPassingStyleInitOperand().
|
static |
Definition at line 305 of file Spmdization.cpp.
References mlir::mesh::MeshSharding::getSplitAxes().
Referenced by tryMoveLastSplitAxisInResharding().
|
static |
Definition at line 153 of file Spmdization.cpp.
References mlir::mesh::MeshSharding::getSplitAxes().
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 203 of file Spmdization.cpp.
References mlir::mesh::MeshSharding::getSplitAxes().
Referenced by tryUnsplitLastAxisInResharding().
|
inline |
Definition at line 182 of file MeshOps.h.
Referenced by allGatherResultShapeInUnsplitLastAxis(), and allToAllResultShapeInMoveLastAxis().
mesh::MeshOp mlir::mesh::getMesh | ( | Op | op, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 135 of file MeshOps.h.
References getMesh(), and mlir::Op< ConcreteType, Traits >::getOperation().
|
inline |
Definition at line 126 of file MeshOps.h.
References getMeshOrNull().
Referenced by getMesh(), mlir::linalg::getMesh(), getMesh< ShardOp >(), reshard(), shardedBlockArgumentTypes(), and spmdizeTriviallyShardableOperation().
|
inline |
ShardingArray mlir::mesh::getMeshAxisAssignmentForLoopIterators | ( | ArrayRef< MeshSharding > | operandShardings, |
ArrayRef< MeshSharding > | resultShardings, | ||
ArrayRef< utils::IteratorType > | loopIteratorTypes, | ||
ArrayRef< AffineMap > | indexingMaps | ||
) |
Definition at line 634 of file ShardingInterface.cpp.
References updateMeshAxisAssignmentForLoopIterators().
|
inline |
Definition at line 120 of file MeshOps.h.
References mlir::SymbolTableCollection::lookupNearestSymbolFrom().
Referenced by getMesh(), and getMeshAndVerify().
FailureOr< std::pair< bool, MeshSharding > > mlir::mesh::getMeshSharding | ( | OpOperand & | opOperand | ) |
Definition at line 153 of file ShardingInterface.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), and mlir::Value::getDefiningOp().
FailureOr< std::pair< bool, MeshSharding > > mlir::mesh::getMeshSharding | ( | OpResult | result | ) |
Definition at line 109 of file ShardingInterface.cpp.
References getSharding(), mlir::Value::getUsers(), and mlir::Value::hasOneUse().
Referenced by visitOp().
|
static |
Definition at line 678 of file Spmdization.cpp.
References mlir::Value::getDefiningOp(), mlir::Operation::getNumOperands(), and mlir::Operation::getOperands().
Referenced by spmdizeOperation().
SmallVector< MeshAxis > mlir::mesh::getReductionMeshAxes | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
ArrayRef< SmallVector< MeshAxis >> | meshAxisAssignmentForLoopIterators | ||
) |
Definition at line 688 of file ShardingInterface.cpp.
Referenced by mlir::linalg::spmdizeLinalgOpWithShardedReduction().
|
static |
Definition at line 698 of file Spmdization.cpp.
References mlir::Operation::getNumResults(), mlir::Operation::getResults(), mlir::Value::getUsers(), and mlir::Value::hasOneUse().
Referenced by spmdizeOperation().
|
static |
Definition at line 58 of file Spmdization.cpp.
References arePartialAxesCompatible(), mlir::OpBuilder::create(), mlir::mesh::MeshSharding::get(), mlir::mesh::MeshSharding::getMeshAttr(), mlir::mesh::MeshSharding::getPartialAxes(), mlir::mesh::MeshSharding::getPartialType(), mlir::mesh::MeshSharding::getSplitAxes(), and mlir::OpBuilder::setInsertionPointAfterValue().
Referenced by reshardOn1DMesh().
bool mlir::mesh::isAtLeastOneReductionIteratorSharded | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
ArrayRef< SmallVector< MeshAxis >> | meshAxisAssignmentForLoopIterators | ||
) |
Definition at line 675 of file ShardingInterface.cpp.
|
inline |
Definition at line 112 of file MeshOps.h.
References mlir::mesh::MeshSharding::getPartialAxes(), and mlir::mesh::MeshSharding::getSplitAxes().
Referenced by isValueCompatibleWithFullReplicationSharding().
|
inline |
Definition at line 100 of file MeshOps.h.
Referenced by mlir::mesh::detail::defaultGetShardingOption(), and getSharding().
void mlir::mesh::maybeInsertSourceShardingAnnotation | ( | MeshSharding | sharding, |
OpOperand & | operand, | ||
OpBuilder & | builder | ||
) |
Definition at line 314 of file MeshOps.cpp.
References mlir::OpBuilder::create(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getDefiningOp(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), mlir::Operation::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by addShardOp().
void mlir::mesh::maybeInsertTargetShardingAnnotation | ( | MeshSharding | sharding, |
OpOperand & | operand, | ||
OpBuilder & | builder | ||
) |
Definition at line 272 of file MeshOps.cpp.
References mlir::OpBuilder::create(), mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), and mlir::OpBuilder::setInsertionPointAfterValue().
Referenced by addShardOp(), and maybeInsertTargetShardingAnnotation().
void mlir::mesh::maybeInsertTargetShardingAnnotation | ( | MeshSharding | sharding, |
OpResult | result, | ||
OpBuilder & | builder | ||
) |
Definition at line 306 of file MeshOps.cpp.
References mlir::Value::getUses(), and maybeInsertTargetShardingAnnotation().
|
static |
Definition at line 389 of file Spmdization.cpp.
References allToAllResultShapeInMoveLastAxis(), mlir::ImplicitLocOpBuilder::create(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInMoveLastAxis().
Referenced by tryMoveLastSplitAxisInResharding().
void mlir::mesh::populateAllOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 189 of file Transforms.cpp.
References populateAllSliceOpLoweringPatterns(), and populateProcessMultiIndexOpLoweringPatterns().
void mlir::mesh::populateAllReduceEndomorphismSimplificationPatterns | ( | RewritePatternSet & | patterns, |
ReductionKind | reduction | ||
) |
Definition at line 40 of file Simplifications.h.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
void mlir::mesh::populateAllSliceOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 177 of file Transforms.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateAllOpLoweringPatterns().
void mlir::mesh::populateFoldingPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 116 of file Simplifications.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateSimplificationPatterns().
void mlir::mesh::populateProcessMultiIndexOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 167 of file Transforms.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateAllOpLoweringPatterns().
void mlir::mesh::populateSimplificationPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 25 of file Simplifications.cpp.
References populateFoldingPatterns().
void mlir::mesh::registerAllOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 195 of file Transforms.cpp.
References registerAllSliceOpLoweringDialects(), and registerProcessMultiIndexOpLoweringDialects().
void mlir::mesh::registerAllSliceOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 183 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
void mlir::mesh::registerProcessMultiIndexOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 173 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
void mlir::mesh::removeTrailingEmptySubArray | ( | SmallVector< SmallVector< T >> & | array | ) |
Definition at line 106 of file MeshOps.h.
Referenced by mlir::mesh::detail::defaultGetShardingOption(), and getSharding().
TypedValue<ShapedType> mlir::mesh::reshard | ( | ImplicitLocOpBuilder & | builder, |
MeshOp | mesh, | ||
MeshSharding | sourceSharding, | ||
MeshSharding | targetSharding, | ||
TypedValue< ShapedType > | sourceUnshardedValue, | ||
TypedValue< ShapedType > | sourceShard | ||
) |
Definition at line 560 of file Spmdization.cpp.
References reshardOn1DMesh(), and tryUpdateHaloInResharding().
Referenced by spmdizeOperation().
TypedValue< ShapedType > mlir::mesh::reshard | ( | OpBuilder & | builder, |
MeshOp | mesh, | ||
ShardOp | source, | ||
ShardOp | target, | ||
TypedValue< ShapedType > | sourceShardValue | ||
) |
Definition at line 585 of file Spmdization.cpp.
Referenced by reshard().
TypedValue< ShapedType > mlir::mesh::reshard | ( | OpBuilder & | builder, |
ShardOp | source, | ||
ShardOp | target, | ||
TypedValue< ShapedType > | sourceShardValue, | ||
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 597 of file Spmdization.cpp.
void mlir::mesh::reshardingRegisterDependentDialects | ( | DialectRegistry & | registry | ) |
Definition at line 606 of file Spmdization.cpp.
References mlir::DialectRegistry::insert().
|
static |
Definition at line 515 of file Spmdization.cpp.
References mlir::mesh::MeshSharding::getStaticHaloSizes(), mlir::mesh::MeshSharding::getStaticShardedDimsOffsets(), handlePartialAxesDuringResharding(), shardShapedType(), tryMoveLastSplitAxisInResharding(), trySplitLastAxisInResharding(), and tryUnsplitLastAxisInResharding().
Referenced by reshard().
|
inline |
Definition at line 173 of file MeshOps.h.
Referenced by allToAllResultShapeInMoveLastAxis().
SmallVector<Type> mlir::mesh::shardedBlockArgumentTypes | ( | Block & | block, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 619 of file Spmdization.cpp.
References mlir::Block::getArguments(), getMesh(), mlir::Operation::getUsers(), and shardShapedType().
Referenced by spmdizeBlock().
ShapedType mlir::mesh::shardShapedType | ( | ShapedType | shape, |
MeshOp | mesh, | ||
MeshSharding | sharding | ||
) |
Definition at line 254 of file MeshOps.cpp.
References mlir::mesh::MeshSharding::getSplitAxes(), mlir::mesh::MeshSharding::getStaticHaloSizes(), mlir::mesh::MeshSharding::getStaticShardedDimsOffsets(), and shardShape().
Referenced by moveLastSplitAxisInResharding(), reshardOn1DMesh(), shardedBlockArgumentTypes(), shardType(), and unsplitLastAxisInResharding().
Type mlir::mesh::shardType | ( | Type | type, |
MeshOp | mesh, | ||
MeshSharding | sharding | ||
) |
Definition at line 264 of file MeshOps.cpp.
References shardShapedType().
Referenced by spmdizeTriviallyShardableOperation().
|
static |
Definition at line 132 of file Spmdization.cpp.
References mlir::ImplicitLocOpBuilder::create(), mlir::Builder::getContext(), and targetShardingInSplitLastAxis().
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 767 of file Spmdization.cpp.
References mlir::OpBuilder::createBlock(), mlir::Block::getArguments(), mlir::Block::getOperations(), mlir::Block::getParent(), mlir::IRMapping::map(), mlir::OpBuilder::setInsertionPointToEnd(), shardedBlockArgumentTypes(), and spmdizeOperation().
Referenced by spmdizeFuncOp().
void mlir::mesh::spmdizeFullyReplicatedOperation | ( | Operation & | op, |
ArrayRef< Value > | spmdizedOperands, | ||
ArrayRef< MeshSharding > | operandShardings, | ||
ArrayRef< MeshSharding > | resultShardings, | ||
IRMapping & | spmdizationMap, | ||
SymbolTableCollection & | symbolTable, | ||
OpBuilder & | builder | ||
) |
Definition at line 605 of file ShardingInterface.cpp.
References areValuesCompatibleWithFullReplicationShardings(), mlir::OpBuilder::clone(), mlir::Operation::getOperands(), and mlir::Operation::getResults().
Referenced by spmdizeOperation().
|
static |
Definition at line 794 of file Spmdization.cpp.
References mlir::get(), mlir::Operation::getOperandTypes(), and spmdizeBlock().
|
static |
Definition at line 649 of file Spmdization.cpp.
References mlir::Operation::getResults(), and spmdizeFullyReplicatedOperation().
|
static |
Definition at line 743 of file Spmdization.cpp.
References mlir::Operation::getOperands(), getOperandShardings(), and getResultShardings().
Referenced by spmdizeBlock().
|
static |
Definition at line 718 of file Spmdization.cpp.
References mlir::IRMapping::contains(), mlir::IRMapping::lookup(), mlir::IRMapping::map(), and reshard().
void mlir::mesh::spmdizeTriviallyShardableOperation | ( | Operation & | op, |
ArrayRef< Value > | spmdizedOperands, | ||
ArrayRef< MeshSharding > | operandShardings, | ||
ArrayRef< MeshSharding > | resultShardings, | ||
IRMapping & | spmdizationMap, | ||
SymbolTableCollection & | symbolTable, | ||
OpBuilder & | builder | ||
) |
Definition at line 701 of file ShardingInterface.cpp.
References mlir::OpBuilder::clone(), getMesh(), mlir::Operation::getResults(), and shardType().
Referenced by mlir::mesh::IndependentParallelIteratorDomainShardingInterface< Op >::spmdize(), mlir::mesh::ElementwiseShardingInterface< ElemwiseOp >::spmdize(), and mlir::linalg::spmdizeLinalgOpWithShardedReduction().
|
static |
Definition at line 346 of file Spmdization.cpp.
References mlir::mesh::MeshSharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::mesh::MeshSharding::getMeshAttr(), mlir::mesh::MeshSharding::getPartialAxes(), mlir::mesh::MeshSharding::getPartialType(), and mlir::mesh::MeshSharding::getSplitAxes().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 108 of file Spmdization.cpp.
References mlir::mesh::MeshSharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::mesh::MeshSharding::getMeshAttr(), mlir::mesh::MeshSharding::getPartialAxes(), mlir::mesh::MeshSharding::getPartialType(), and mlir::mesh::MeshSharding::getSplitAxes().
Referenced by splitLastAxisInResharding().
|
static |
Definition at line 231 of file Spmdization.cpp.
References mlir::mesh::MeshSharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::mesh::MeshSharding::getMeshAttr(), mlir::mesh::MeshSharding::getPartialAxes(), mlir::mesh::MeshSharding::getPartialType(), and mlir::mesh::MeshSharding::getSplitAxes().
Referenced by unsplitLastAxisInResharding().
|
static |
Definition at line 416 of file Spmdization.cpp.
References detectMoveLastSplitAxisInResharding(), and moveLastSplitAxisInResharding().
Referenced by reshardOn1DMesh().
|
static |
Definition at line 185 of file Spmdization.cpp.
References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().
Referenced by reshardOn1DMesh().
|
static |
Definition at line 283 of file Spmdization.cpp.
References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().
Referenced by reshardOn1DMesh().
|
static |
Definition at line 437 of file Spmdization.cpp.
References mlir::ImplicitLocOpBuilder::create(), mlir::mesh::MeshSharding::equalHaloSizes(), mlir::mesh::MeshSharding::equalSplitAndPartialAxes(), mlir::get(), mlir::Builder::getContext(), mlir::mesh::MeshSharding::getDynamicHaloSizes(), mlir::mesh::MeshSharding::getPartialAxes(), mlir::mesh::MeshSharding::getSplitAxes(), mlir::mesh::MeshSharding::getStaticHaloSizes(), and mlir::mesh::MeshSharding::getStaticShardedDimsOffsets().
Referenced by reshard().
|
static |
Definition at line 258 of file Spmdization.cpp.
References allGatherResultShapeInUnsplitLastAxis(), mlir::ImplicitLocOpBuilder::create(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInUnsplitLastAxis().
Referenced by tryUnsplitLastAxisInResharding().