MLIR  20.0.0git
Namespaces | Classes | Typedefs | Functions
mlir::mesh Namespace Reference

Namespaces

 detail
 

Classes

struct  ShardingOption
 
struct  IndependentParallelIteratorDomainShardingInterface
 
struct  ElementwiseShardingInterface
 
class  MeshSharding
 
struct  OpRewritePatternWithSymbolTableCollection
 

Typedefs

using ShardingArray = SmallVector< SmallVector< MeshAxis > >
 
using ShardingArrayRef = ArrayRef< SmallVector< MeshAxis > >
 
using MeshAxis = int16_t
 
using MeshAxesAttr = DenseI16ArrayAttr
 
using ShardShapeAttr = DenseI64ArrayAttr
 
using HaloSizePairAttr = DenseI64ArrayAttr
 
using UnshardedToShardedValueMap = DenseMap< Value, Value >
 

Functions

FailureOr< std::pair< bool, MeshSharding > > getMeshSharding (OpResult result)
 
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding (OpOperand &opOperand)
 
void spmdizeFullyReplicatedOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
ShardingArray getMeshAxisAssignmentForLoopIterators (ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
 
bool isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
 
SmallVector< MeshAxis > getReductionMeshAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
 
void spmdizeTriviallyShardableOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
bool isReductionLoop (utils::IteratorType iType)
 
template<typename T >
void removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array)
 
bool isFullReplication (MeshSharding sharding)
 
mesh::MeshOp getMeshOrNull (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
 
mesh::MeshOp getMesh (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
 
template<typename Op >
mesh::MeshOp getMesh (Op op, SymbolTableCollection &symbolTableCollection)
 
template<>
mesh::MeshOp getMesh< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection)
 
template<typename MeshAxesRange , typename MeshShapeRange >
int64_t collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
 
template<typename MeshAxesRange >
int64_t collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshOp mesh)
 
int64_t shardDimension (int64_t dimSize, int64_t shardCount)
 
int64_t gatherDimension (int64_t dimSize, int64_t shardCount)
 
ShapedType shardShapedType (ShapedType shape, MeshOp mesh, MeshSharding sharding)
 
Type shardType (Type type, MeshOp mesh, MeshSharding sharding)
 
void maybeInsertTargetShardingAnnotation (MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
 
void maybeInsertTargetShardingAnnotation (MeshSharding sharding, OpResult result, OpBuilder &builder)
 
void maybeInsertSourceShardingAnnotation (MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
 
template<typename AlgebraicOp >
void populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction)
 
void populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
TypedValue< ShapedType > reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
 
TypedValue< ShapedType > reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
 
void reshardingRegisterDependentDialects (DialectRegistry &registry)
 
void populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry)
 
void populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllSliceOpLoweringDialects (DialectRegistry &registry)
 
void populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllOpLoweringDialects (DialectRegistry &registry)
 
TypedValue< IndexType > createCollectiveProcessGroupSize (MeshOp mesh, ArrayRef< MeshAxis > axes, ImplicitLocOpBuilder &builder)
 
TypedValue< IndexType > createProcessLinearIndex (StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
 
template<typename SourceAxes , typename TargetAxes >
static bool arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > handlePartialAxesDuringResharding (OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static MeshSharding targetShardingInSplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static MeshSharding targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
 
static ShapedType allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static MeshSharding targetShardingInMoveLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static ShapedType allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
SmallVector< Type > shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
 
static LogicalResult spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static std::vector< MeshSharding > getOperandShardings (Operation &op)
 
static std::vector< MeshSharding > getResultShardings (Operation &op)
 
static LogicalResult spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
 

Typedef Documentation

◆ HaloSizePairAttr

Definition at line 29 of file MeshOps.h.

◆ MeshAxesAttr

Definition at line 27 of file MeshOps.h.

◆ MeshAxis

using mlir::mesh::MeshAxis = int16_t

Definition at line 26 of file MeshOps.h.

◆ ShardingArray

Definition at line 25 of file ShardingInterface.h.

◆ ShardingArrayRef

Definition at line 26 of file ShardingInterface.h.

◆ ShardShapeAttr

Definition at line 28 of file MeshOps.h.

◆ UnshardedToShardedValueMap

Definition at line 611 of file Spmdization.cpp.

Function Documentation

◆ allGatherResultShapeInUnsplitLastAxis()

static ShapedType mlir::mesh::allGatherResultShapeInUnsplitLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  splitTensorAxis 
)
static

Definition at line 249 of file Spmdization.cpp.

References gatherDimension().

Referenced by unsplitLastAxisInResharding().

◆ allToAllResultShapeInMoveLastAxis()

static ShapedType mlir::mesh::allToAllResultShapeInMoveLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

Definition at line 376 of file Spmdization.cpp.

References gatherDimension(), and shardDimension().

Referenced by moveLastSplitAxisInResharding().

◆ arePartialAxesCompatible()

template<typename SourceAxes , typename TargetAxes >
static bool mlir::mesh::arePartialAxesCompatible ( const SourceAxes &  sourceAxes,
const TargetAxes &  targetAxes 
)
static

Definition at line 44 of file Spmdization.cpp.

Referenced by handlePartialAxesDuringResharding().

◆ collectiveProcessGroupSize() [1/2]

template<typename MeshAxesRange >
int64_t mlir::mesh::collectiveProcessGroupSize ( MeshAxesRange &&  meshAxes,
MeshOp  mesh 
)

Definition at line 167 of file MeshOps.h.

References collectiveProcessGroupSize().

◆ collectiveProcessGroupSize() [2/2]

template<typename MeshAxesRange , typename MeshShapeRange >
int64_t mlir::mesh::collectiveProcessGroupSize ( MeshAxesRange &&  meshAxes,
MeshShapeRange &&  meshShape 
)

◆ createCollectiveProcessGroupSize()

TypedValue< IndexType > mlir::mesh::createCollectiveProcessGroupSize ( MeshOp  mesh,
ArrayRef< MeshAxis >  axes,
ImplicitLocOpBuilder &  builder 
)

◆ createProcessLinearIndex()

TypedValue< IndexType > mlir::mesh::createProcessLinearIndex ( StringRef  mesh,
ArrayRef< MeshAxis >  meshAxes,
ImplicitLocOpBuilder &  builder 
)

◆ detectMoveLastSplitAxisInResharding()

static std::optional<std::tuple<int64_t, int64_t, MeshAxis> > mlir::mesh::detectMoveLastSplitAxisInResharding ( MeshSharding  sourceSharding,
MeshSharding  targetSharding 
)
static

◆ detectSplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, MeshAxis> > mlir::mesh::detectSplitLastAxisInResharding ( MeshSharding  sourceSharding,
MeshSharding  targetSharding 
)
static

Definition at line 153 of file Spmdization.cpp.

References mlir::mesh::MeshSharding::getSplitAxes().

Referenced by trySplitLastAxisInResharding().

◆ detectUnsplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, MeshAxis> > mlir::mesh::detectUnsplitLastAxisInResharding ( MeshSharding  sourceSharding,
MeshSharding  targetSharding 
)
static

Definition at line 203 of file Spmdization.cpp.

References mlir::mesh::MeshSharding::getSplitAxes().

Referenced by tryUnsplitLastAxisInResharding().

◆ gatherDimension()

int64_t mlir::mesh::gatherDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

◆ getMesh() [1/2]

template<typename Op >
mesh::MeshOp mlir::mesh::getMesh ( Op  op,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 135 of file MeshOps.h.

References getMesh(), and mlir::Op< ConcreteType, Traits >::getOperation().

◆ getMesh() [2/2]

mesh::MeshOp mlir::mesh::getMesh ( Operation *  op,
FlatSymbolRefAttr  meshSymbol,
SymbolTableCollection &  symbolTableCollection 
)
inline

◆ getMesh< ShardOp >()

template<>
mesh::MeshOp mlir::mesh::getMesh< ShardOp > ( ShardOp  op,
SymbolTableCollection &  symbolTableCollection 
)
inline

Definition at line 140 of file MeshOps.h.

References getMesh().

◆ getMeshAxisAssignmentForLoopIterators()

ShardingArray mlir::mesh::getMeshAxisAssignmentForLoopIterators ( ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< AffineMap >  indexingMaps 
)

Definition at line 634 of file ShardingInterface.cpp.

References updateMeshAxisAssignmentForLoopIterators().

◆ getMeshOrNull()

mesh::MeshOp mlir::mesh::getMeshOrNull ( Operation *  op,
FlatSymbolRefAttr  meshSymbol,
SymbolTableCollection &  symbolTableCollection 
)
inline

Definition at line 120 of file MeshOps.h.

References mlir::SymbolTableCollection::lookupNearestSymbolFrom().

Referenced by getMesh(), and getMeshAndVerify().

◆ getMeshSharding() [1/2]

FailureOr< std::pair< bool, MeshSharding > > mlir::mesh::getMeshSharding ( OpOperand &  opOperand)

◆ getMeshSharding() [2/2]

FailureOr< std::pair< bool, MeshSharding > > mlir::mesh::getMeshSharding ( OpResult  result)

Definition at line 109 of file ShardingInterface.cpp.

References getSharding(), mlir::Value::getUsers(), and mlir::Value::hasOneUse().

Referenced by visitOp().

◆ getOperandShardings()

static std::vector<MeshSharding> mlir::mesh::getOperandShardings ( Operation &  op)
static

◆ getReductionMeshAxes()

SmallVector< MeshAxis > mlir::mesh::getReductionMeshAxes ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< MeshAxis >>  meshAxisAssignmentForLoopIterators 
)

◆ getResultShardings()

static std::vector<MeshSharding> mlir::mesh::getResultShardings ( Operation &  op)
static

◆ handlePartialAxesDuringResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::handlePartialAxesDuringResharding ( OpBuilder &  builder,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceShard 
)
static

◆ isAtLeastOneReductionIteratorSharded()

bool mlir::mesh::isAtLeastOneReductionIteratorSharded ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< MeshAxis >>  meshAxisAssignmentForLoopIterators 
)

Definition at line 675 of file ShardingInterface.cpp.

◆ isFullReplication()

bool mlir::mesh::isFullReplication ( MeshSharding  sharding)
inline

◆ isReductionLoop()

bool mlir::mesh::isReductionLoop ( utils::IteratorType  iType)
inline

Definition at line 100 of file MeshOps.h.

Referenced by mlir::mesh::detail::defaultGetShardingOption(), and getSharding().

◆ maybeInsertSourceShardingAnnotation()

void mlir::mesh::maybeInsertSourceShardingAnnotation ( MeshSharding  sharding,
OpOperand &  operand,
OpBuilder &  builder 
)

◆ maybeInsertTargetShardingAnnotation() [1/2]

void mlir::mesh::maybeInsertTargetShardingAnnotation ( MeshSharding  sharding,
OpOperand &  operand,
OpBuilder &  builder 
)

◆ maybeInsertTargetShardingAnnotation() [2/2]

void mlir::mesh::maybeInsertTargetShardingAnnotation ( MeshSharding  sharding,
OpResult  result,
OpBuilder &  builder 
)

Definition at line 306 of file MeshOps.cpp.

References mlir::Value::getUses(), and maybeInsertTargetShardingAnnotation().

◆ moveLastSplitAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::moveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis,
MeshAxis  meshAxis 
)
static

◆ populateAllOpLoweringPatterns()

void mlir::mesh::populateAllOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

◆ populateAllReduceEndomorphismSimplificationPatterns()

template<typename AlgebraicOp >
void mlir::mesh::populateAllReduceEndomorphismSimplificationPatterns ( RewritePatternSet &  patterns,
ReductionKind  reduction 
)

Definition at line 40 of file Simplifications.h.

References mlir::patterns.

◆ populateAllSliceOpLoweringPatterns()

void mlir::mesh::populateAllSliceOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 177 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateFoldingPatterns()

void mlir::mesh::populateFoldingPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 116 of file Simplifications.cpp.

References mlir::patterns.

Referenced by populateSimplificationPatterns().

◆ populateProcessMultiIndexOpLoweringPatterns()

void mlir::mesh::populateProcessMultiIndexOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 167 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateSimplificationPatterns()

void mlir::mesh::populateSimplificationPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 25 of file Simplifications.cpp.

References mlir::patterns, and populateFoldingPatterns().

◆ registerAllOpLoweringDialects()

void mlir::mesh::registerAllOpLoweringDialects ( DialectRegistry &  registry)

◆ registerAllSliceOpLoweringDialects()

void mlir::mesh::registerAllSliceOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 183 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerProcessMultiIndexOpLoweringDialects()

void mlir::mesh::registerProcessMultiIndexOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 173 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ removeTrailingEmptySubArray()

template<typename T >
void mlir::mesh::removeTrailingEmptySubArray ( SmallVector< SmallVector< T >> &  array)

Definition at line 106 of file MeshOps.h.

Referenced by mlir::mesh::detail::defaultGetShardingOption(), and getSharding().

◆ reshard() [1/3]

TypedValue<ShapedType> mlir::mesh::reshard ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)

Definition at line 558 of file Spmdization.cpp.

References reshardOn1DMesh(), and tryUpdateHaloInResharding().

Referenced by spmdizeOperation().

◆ reshard() [2/3]

TypedValue< ShapedType > mlir::mesh::reshard ( OpBuilder &  builder,
MeshOp  mesh,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue 
)

Definition at line 583 of file Spmdization.cpp.

Referenced by reshard().

◆ reshard() [3/3]

TypedValue< ShapedType > mlir::mesh::reshard ( OpBuilder &  builder,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 595 of file Spmdization.cpp.

References getMesh(), and reshard().

◆ reshardingRegisterDependentDialects()

void mlir::mesh::reshardingRegisterDependentDialects ( DialectRegistry &  registry)

Definition at line 604 of file Spmdization.cpp.

References mlir::DialectRegistry::insert().

◆ reshardOn1DMesh()

static TypedValue<ShapedType> mlir::mesh::reshardOn1DMesh ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)
static

◆ shardDimension()

int64_t mlir::mesh::shardDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

Definition at line 173 of file MeshOps.h.

Referenced by allToAllResultShapeInMoveLastAxis().

◆ shardedBlockArgumentTypes()

SmallVector<Type> mlir::mesh::shardedBlockArgumentTypes ( Block &  block,
SymbolTableCollection &  symbolTableCollection 
)

◆ shardShapedType()

ShapedType mlir::mesh::shardShapedType ( ShapedType  shape,
MeshOp  mesh,
MeshSharding  sharding 
)

◆ shardType()

Type mlir::mesh::shardType ( Type  type,
MeshOp  mesh,
MeshSharding  sharding 
)

Definition at line 264 of file MeshOps.cpp.

References shardShapedType().

Referenced by spmdizeTriviallyShardableOperation().

◆ splitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::splitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshSharding  sourceSharding,
TypedValue< ShapedType >  sourceShard,
MeshOp  mesh,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static

◆ spmdizeBlock()

static LogicalResult mlir::mesh::spmdizeBlock ( Block &  block,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeFullyReplicatedOperation()

void mlir::mesh::spmdizeFullyReplicatedOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

◆ spmdizeFuncOp()

static LogicalResult mlir::mesh::spmdizeFuncOp ( FunctionOpInterface  op,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection 
)
static

Definition at line 792 of file Spmdization.cpp.

References mlir::get(), mlir::Operation::getOperandTypes(), and spmdizeBlock().

◆ spmdizeOperation() [1/3]

static LogicalResult mlir::mesh::spmdizeOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeOperation() [2/3]

static LogicalResult mlir::mesh::spmdizeOperation ( Operation &  op,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeOperation() [3/3]

static LogicalResult mlir::mesh::spmdizeOperation ( ShardOp  shardOp,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeTriviallyShardableOperation()

void mlir::mesh::spmdizeTriviallyShardableOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

◆ targetShardingInMoveLastAxis()

static MeshSharding mlir::mesh::targetShardingInMoveLastAxis ( MLIRContext *  ctx,
MeshSharding  sourceSharding,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

◆ targetShardingInSplitLastAxis()

static MeshSharding mlir::mesh::targetShardingInSplitLastAxis ( MLIRContext *  ctx,
MeshSharding  sourceSharding,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static

◆ targetShardingInUnsplitLastAxis()

static MeshSharding mlir::mesh::targetShardingInUnsplitLastAxis ( MLIRContext *  ctx,
MeshSharding  sourceSharding,
int64_t  splitTensorAxis 
)
static

◆ tryMoveLastSplitAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::tryMoveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

◆ trySplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::trySplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 185 of file Spmdization.cpp.

References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().

Referenced by reshardOn1DMesh().

◆ tryUnsplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::tryUnsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 283 of file Spmdization.cpp.

References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().

Referenced by reshardOn1DMesh().

◆ tryUpdateHaloInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::tryUpdateHaloInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

◆ unsplitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::unsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshSharding  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
MeshOp  mesh,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static