MLIR  19.0.0git
Namespaces | Classes | Typedefs | Functions
mlir::mesh Namespace Reference

Namespaces

 detail
 

Classes

struct  ShardingOption
 
struct  IndependentParallelIteratorDomainShardingInterface
 
struct  ElementwiseShardingInterface
 
struct  OpRewritePatternWithSymbolTableCollection
 

Typedefs

using ShardingArray = SmallVector< SmallVector< MeshAxis > >
 
using ShardingArrayRef = ArrayRef< SmallVector< MeshAxis > >
 
using MeshAxis = int16_t
 
using MeshAxesAttr = DenseI16ArrayAttr
 
using UnshardedToShardedValueMap = DenseMap< Value, Value >
 

Functions

FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr (OpResult result)
 
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr (OpOperand &opOperand)
 
void spmdizeFullyReplicatedOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
ShardingArray getMeshAxisAssignmentForLoopIterators (ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
 
bool isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
 
SmallVector< MeshAxis > getReductionMeshAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
 
void spmdizeTriviallyShardableOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
bool isReductionLoop (utils::IteratorType iType)
 
template<typename T >
void removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array)
 
bool isFullReplication (MeshShardingAttr attr)
 
mesh::MeshOp getMesh (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
 
template<typename Op >
mesh::MeshOp getMesh (Op op, SymbolTableCollection &symbolTableCollection)
 
template<>
mesh::MeshOp getMesh< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection)
 
template<typename MeshAxesRange , typename MeshShapeRange >
int64_t collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
 
template<typename MeshAxesRange >
int64_t collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshOp mesh)
 
int64_t shardDimension (int64_t dimSize, int64_t shardCount)
 
int64_t gatherDimension (int64_t dimSize, int64_t shardCount)
 
ShapedType shardShapedType (ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
 
Type shardType (Type type, MeshOp mesh, MeshShardingAttr sharding)
 
template<typename AlgebraicOp >
void populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction)
 
void populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
TypedValue< ShapedType > reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
 
TypedValue< ShapedType > reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
 
void reshardingRegisterDependentDialects (DialectRegistry &registry)
 
void populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry)
 
void populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllSliceOpLoweringDialects (DialectRegistry &registry)
 
void populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllOpLoweringDialects (DialectRegistry &registry)
 
TypedValue< IndexType > createCollectiveProcessGroupSize (MeshOp mesh, ArrayRef< MeshAxis > axes, ImplicitLocOpBuilder &builder)
 
TypedValue< IndexType > createProcessLinearIndex (StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
 
template<typename SourceAxes , typename TargetAxes >
static bool arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > handlePartialAxesDuringResharding (OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
 
static MeshShardingAttr targetShardingInSplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
 
static MeshShardingAttr targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis)
 
static ShapedType allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
 
static MeshShardingAttr targetShardingInMoveLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static ShapedType allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
SmallVector< Type > shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
 
static LogicalResult spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static SmallVector< MeshShardingAttr > getOperandShardings (Operation &op)
 
static SmallVector< MeshShardingAttr > getResultShardings (Operation &op)
 
static LogicalResult spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
 

Typedef Documentation

◆ MeshAxesAttr

Definition at line 26 of file MeshOps.h.

◆ MeshAxis

using mlir::mesh::MeshAxis = int16_t

Definition at line 25 of file MeshOps.h.

◆ ShardingArray

Definition at line 25 of file ShardingInterface.h.

◆ ShardingArrayRef

Definition at line 26 of file ShardingInterface.h.

◆ UnshardedToShardedValueMap

Definition at line 521 of file Spmdization.cpp.

Function Documentation

◆ allGatherResultShapeInUnsplitLastAxis()

static ShapedType mlir::mesh::allGatherResultShapeInUnsplitLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  splitTensorAxis 
)
static

Definition at line 251 of file Spmdization.cpp.

References gatherDimension().

Referenced by unsplitLastAxisInResharding().

◆ allToAllResultShapeInMoveLastAxis()

static ShapedType mlir::mesh::allToAllResultShapeInMoveLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

Definition at line 378 of file Spmdization.cpp.

References gatherDimension(), and shardDimension().

Referenced by moveLastSplitAxisInResharding().

◆ arePartialAxesCompatible()

template<typename SourceAxes , typename TargetAxes >
static bool mlir::mesh::arePartialAxesCompatible ( const SourceAxes &  sourceAxes,
const TargetAxes &  targetAxes 
)
static

Definition at line 45 of file Spmdization.cpp.

Referenced by handlePartialAxesDuringResharding().

◆ collectiveProcessGroupSize() [1/2]

template<typename MeshAxesRange >
int64_t mlir::mesh::collectiveProcessGroupSize ( MeshAxesRange &&  meshAxes,
MeshOp  mesh 
)

Definition at line 95 of file MeshOps.h.

References collectiveProcessGroupSize().

◆ collectiveProcessGroupSize() [2/2]

template<typename MeshAxesRange , typename MeshShapeRange >
int64_t mlir::mesh::collectiveProcessGroupSize ( MeshAxesRange &&  meshAxes,
MeshShapeRange &&  meshShape 
)

◆ createCollectiveProcessGroupSize()

TypedValue< IndexType > mlir::mesh::createCollectiveProcessGroupSize ( MeshOp  mesh,
ArrayRef< MeshAxis >  axes,
ImplicitLocOpBuilder &  builder 
)

◆ createProcessLinearIndex()

TypedValue< IndexType > mlir::mesh::createProcessLinearIndex ( StringRef  mesh,
ArrayRef< MeshAxis >  meshAxes,
ImplicitLocOpBuilder &  builder 
)

◆ detectMoveLastSplitAxisInResharding()

static std::optional<std::tuple<int64_t, int64_t, MeshAxis> > mlir::mesh::detectMoveLastSplitAxisInResharding ( MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding 
)
static

Definition at line 307 of file Spmdization.cpp.

Referenced by tryMoveLastSplitAxisInResharding().

◆ detectSplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, MeshAxis> > mlir::mesh::detectSplitLastAxisInResharding ( MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding 
)
static

Definition at line 154 of file Spmdization.cpp.

Referenced by trySplitLastAxisInResharding().

◆ detectUnsplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, MeshAxis> > mlir::mesh::detectUnsplitLastAxisInResharding ( MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding 
)
static

Definition at line 204 of file Spmdization.cpp.

Referenced by tryUnsplitLastAxisInResharding().

◆ gatherDimension()

int64_t mlir::mesh::gatherDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

◆ getMesh() [1/2]

template<typename Op >
mesh::MeshOp mlir::mesh::getMesh ( Op  op,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 65 of file MeshOps.h.

◆ getMesh() [2/2]

mesh::MeshOp mlir::mesh::getMesh ( Operation *  op,
FlatSymbolRefAttr  meshSymbol,
SymbolTableCollection &  symbolTableCollection 
)
inline

Definition at line 57 of file MeshOps.h.

Referenced by reshard(), and shardedBlockArgumentTypes().

◆ getMesh< ShardOp >()

template<>
mesh::MeshOp mlir::mesh::getMesh< ShardOp > ( ShardOp  op,
SymbolTableCollection &  symbolTableCollection 
)
inline

Definition at line 70 of file MeshOps.h.

◆ getMeshAxisAssignmentForLoopIterators()

ShardingArray mlir::mesh::getMeshAxisAssignmentForLoopIterators ( ArrayRef< MeshShardingAttr >  operandShardings,
ArrayRef< MeshShardingAttr >  resultShardings,
ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< AffineMap >  indexingMaps 
)

Definition at line 582 of file ShardingInterface.cpp.

References updateMeshAxisAssignmentForLoopIterators().

◆ getMeshShardingAttr() [1/2]

FailureOr< std::pair< bool, MeshShardingAttr > > mlir::mesh::getMeshShardingAttr ( OpOperand &  opOperand)

◆ getMeshShardingAttr() [2/2]

FailureOr< std::pair< bool, MeshShardingAttr > > mlir::mesh::getMeshShardingAttr ( OpResult  result)

Definition at line 99 of file ShardingInterface.cpp.

References mlir::failure(), mlir::Value::getUsers(), and mlir::Value::hasOneUse().

Referenced by addShardOp().

◆ getOperandShardings()

static SmallVector<MeshShardingAttr> mlir::mesh::getOperandShardings ( Operation &  op)
static

Definition at line 578 of file Spmdization.cpp.

◆ getReductionMeshAxes()

SmallVector< MeshAxis > mlir::mesh::getReductionMeshAxes ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< MeshAxis >>  meshAxisAssignmentForLoopIterators 
)

Definition at line 636 of file ShardingInterface.cpp.

◆ getResultShardings()

static SmallVector<MeshShardingAttr> mlir::mesh::getResultShardings ( Operation &  op)
static

Definition at line 598 of file Spmdization.cpp.

◆ handlePartialAxesDuringResharding()

static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> mlir::mesh::handlePartialAxesDuringResharding ( OpBuilder &  builder,
MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding,
TypedValue< ShapedType >  sourceShard 
)
static

◆ isAtLeastOneReductionIteratorSharded()

bool mlir::mesh::isAtLeastOneReductionIteratorSharded ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< MeshAxis >>  meshAxisAssignmentForLoopIterators 
)

Definition at line 623 of file ShardingInterface.cpp.

◆ isFullReplication()

bool mlir::mesh::isFullReplication ( MeshShardingAttr  attr)
inline

Definition at line 53 of file MeshOps.h.

Referenced by isValueCompatibleWithFullReplicationSharding().

◆ isReductionLoop()

bool mlir::mesh::isReductionLoop ( utils::IteratorType  iType)
inline

Definition at line 42 of file MeshOps.h.

Referenced by addShardOp().

◆ moveLastSplitAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> mlir::mesh::moveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshShardingAttr  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis,
MeshAxis  meshAxis 
)
static

◆ populateAllOpLoweringPatterns()

void mlir::mesh::populateAllOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

◆ populateAllReduceEndomorphismSimplificationPatterns()

template<typename AlgebraicOp >
void mlir::mesh::populateAllReduceEndomorphismSimplificationPatterns ( RewritePatternSet &  patterns,
ReductionKind  reduction 
)

Definition at line 40 of file Simplifications.h.

◆ populateAllSliceOpLoweringPatterns()

void mlir::mesh::populateAllSliceOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

◆ populateFoldingPatterns()

void mlir::mesh::populateFoldingPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

◆ populateProcessMultiIndexOpLoweringPatterns()

void mlir::mesh::populateProcessMultiIndexOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

◆ populateSimplificationPatterns()

void mlir::mesh::populateSimplificationPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 26 of file Simplifications.cpp.

References populateFoldingPatterns().

◆ registerAllOpLoweringDialects()

void mlir::mesh::registerAllOpLoweringDialects ( DialectRegistry &  registry)

◆ registerAllSliceOpLoweringDialects()

void mlir::mesh::registerAllSliceOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 183 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerProcessMultiIndexOpLoweringDialects()

void mlir::mesh::registerProcessMultiIndexOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 173 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ removeTrailingEmptySubArray()

template<typename T >
void mlir::mesh::removeTrailingEmptySubArray ( SmallVector< SmallVector< T >> &  array)

Definition at line 47 of file MeshOps.h.

Referenced by addShardOp().

◆ reshard() [1/3]

TypedValue<ShapedType> mlir::mesh::reshard ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)

Definition at line 481 of file Spmdization.cpp.

References reshardOn1DMesh().

Referenced by spmdizeOperation().

◆ reshard() [2/3]

TypedValue< ShapedType > mlir::mesh::reshard ( OpBuilder &  builder,
MeshOp  mesh,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue 
)

Definition at line 493 of file Spmdization.cpp.

Referenced by reshard().

◆ reshard() [3/3]

TypedValue< ShapedType > mlir::mesh::reshard ( OpBuilder &  builder,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 505 of file Spmdization.cpp.

References getMesh(), and reshard().

◆ reshardingRegisterDependentDialects()

void mlir::mesh::reshardingRegisterDependentDialects ( DialectRegistry &  registry)

Definition at line 514 of file Spmdization.cpp.

References mlir::DialectRegistry::insert().

◆ reshardOn1DMesh()

static TypedValue<ShapedType> mlir::mesh::reshardOn1DMesh ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)
static

◆ shardDimension()

int64_t mlir::mesh::shardDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

Definition at line 101 of file MeshOps.h.

References mlir::ceilDiv().

Referenced by allToAllResultShapeInMoveLastAxis(), and shardShape().

◆ shardedBlockArgumentTypes()

SmallVector<Type> mlir::mesh::shardedBlockArgumentTypes ( Block &  block,
SymbolTableCollection &  symbolTableCollection 
)

◆ shardShapedType()

ShapedType mlir::mesh::shardShapedType ( ShapedType  shape,
MeshOp  mesh,
MeshShardingAttr  sharding 
)

◆ shardType()

Type mlir::mesh::shardType ( Type  type,
MeshOp  mesh,
MeshShardingAttr  sharding 
)

Definition at line 171 of file MeshOps.cpp.

References shardShapedType().

◆ splitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> mlir::mesh::splitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshShardingAttr  sourceSharding,
TypedValue< ShapedType >  sourceShard,
MeshOp  mesh,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static

◆ spmdizeBlock()

static LogicalResult mlir::mesh::spmdizeBlock ( Block &  block,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeFullyReplicatedOperation()

void mlir::mesh::spmdizeFullyReplicatedOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshShardingAttr >  operandShardings,
ArrayRef< MeshShardingAttr >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

Definition at line 553 of file ShardingInterface.cpp.

◆ spmdizeFuncOp()

static LogicalResult mlir::mesh::spmdizeFuncOp ( FunctionOpInterface  op,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection 
)
static

Definition at line 691 of file Spmdization.cpp.

◆ spmdizeOperation() [1/3]

static LogicalResult mlir::mesh::spmdizeOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshShardingAttr >  operandShardings,
ArrayRef< MeshShardingAttr >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

Definition at line 549 of file Spmdization.cpp.

◆ spmdizeOperation() [2/3]

static LogicalResult mlir::mesh::spmdizeOperation ( Operation &  op,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

Definition at line 644 of file Spmdization.cpp.

◆ spmdizeOperation() [3/3]

static LogicalResult mlir::mesh::spmdizeOperation ( ShardOp  shardOp,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeTriviallyShardableOperation()

void mlir::mesh::spmdizeTriviallyShardableOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshShardingAttr >  operandShardings,
ArrayRef< MeshShardingAttr >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

Definition at line 649 of file ShardingInterface.cpp.

◆ targetShardingInMoveLastAxis()

static MeshShardingAttr mlir::mesh::targetShardingInMoveLastAxis ( MLIRContext *  ctx,
MeshShardingAttr  sourceSharding,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

◆ targetShardingInSplitLastAxis()

static MeshShardingAttr mlir::mesh::targetShardingInSplitLastAxis ( MLIRContext *  ctx,
MeshShardingAttr  sourceSharding,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static

◆ targetShardingInUnsplitLastAxis()

static MeshShardingAttr mlir::mesh::targetShardingInUnsplitLastAxis ( MLIRContext *  ctx,
MeshShardingAttr  sourceSharding,
int64_t  splitTensorAxis 
)
static

◆ tryMoveLastSplitAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr> > mlir::mesh::tryMoveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

◆ trySplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr> > mlir::mesh::trySplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 186 of file Spmdization.cpp.

References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().

Referenced by reshardOn1DMesh().

◆ tryUnsplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr> > mlir::mesh::tryUnsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshShardingAttr  sourceSharding,
MeshShardingAttr  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 285 of file Spmdization.cpp.

References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().

Referenced by reshardOn1DMesh().

◆ unsplitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> mlir::mesh::unsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshShardingAttr  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
MeshOp  mesh,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static