MLIR  22.0.0git
Namespaces | Classes | Typedefs | Enumerations | Functions
mlir::mesh Namespace Reference

Namespaces

 detail
 

Classes

struct  ShardingOption
 
struct  IndependentParallelIteratorDomainShardingInterface
 
struct  ElementwiseShardingInterface
 
class  MeshSharding
 
struct  OpRewritePatternWithSymbolTableCollection
 

Typedefs

using ShardingArray = SmallVector< SmallVector< MeshAxis > >
 
using ShardingArrayRef = ArrayRef< SmallVector< MeshAxis > >
 
using MeshAxis = int16_t
 
using MeshAxesAttr = DenseI16ArrayAttr
 
using ShardShapeAttr = DenseI64ArrayAttr
 
using HaloSizePairAttr = DenseI64ArrayAttr
 
using UnshardedToShardedValueMap = DenseMap< Value, Value >
 

Enumerations

enum class  TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward }
 This enum controls the traversal order for the sharding propagation. More...
 

Functions

FailureOr< std::pair< bool, MeshSharding > > getMeshSharding (OpResult result)
 
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding (OpOperand &opOperand)
 
void spmdizeFullyReplicatedOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
ShardingArray getMeshAxisAssignmentForLoopIterators (ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
 
bool isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
 
SmallVector< MeshAxis > getReductionMeshAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
 
void spmdizeTriviallyShardableOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
bool isReductionLoop (utils::IteratorType iType)
 
template<typename T >
void removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array)
 
bool isFullReplication (MeshSharding sharding)
 
mesh::MeshOp getMeshOrNull (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
 
mesh::MeshOp getMesh (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
 
template<typename Op >
mesh::MeshOp getMesh (Op op, SymbolTableCollection &symbolTableCollection)
 
template<>
mesh::MeshOp getMesh< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection)
 
template<typename MeshAxesRange , typename MeshShapeRange >
int64_t collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
 
template<typename MeshAxesRange >
int64_t collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshOp mesh)
 
int64_t shardDimension (int64_t dimSize, int64_t shardCount)
 
int64_t gatherDimension (int64_t dimSize, int64_t shardCount)
 
ShapedType shardShapedType (ShapedType shape, MeshOp mesh, MeshSharding sharding)
 
Type shardType (Type type, MeshOp mesh, MeshSharding sharding)
 
void maybeInsertTargetShardingAnnotation (MeshSharding sharding, OpResult result, OpBuilder &builder)
 
void maybeInsertSourceShardingAnnotation (MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
 
SmallVector< Value > getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
 Converts a vector of OpFoldResults (ints) into vector of Values of the provided type. More...
 
template<typename AlgebraicOp >
void populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction)
 
void populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
TypedValue< ShapedType > reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
 
TypedValue< ShapedType > reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
 
void reshardingRegisterDependentDialects (DialectRegistry &registry)
 
void populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry)
 
void populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllSliceOpLoweringDialects (DialectRegistry &registry)
 
void populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllOpLoweringDialects (DialectRegistry &registry)
 
TypedValue< IndexType > createCollectiveProcessGroupSize (MeshOp mesh, ArrayRef< MeshAxis > axes, ImplicitLocOpBuilder &builder)
 
TypedValue< IndexType > createProcessLinearIndex (StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
 
TypedValue< IndexType > createProcessLinearIndex (StringRef mesh, ValueRange processInGroupMultiIndex, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder)
 
template<typename SourceAxes , typename TargetAxes >
static bool arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > handlePartialAxesDuringResharding (OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static MeshSharding targetShardingInSplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static MeshSharding targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
 
static ShapedType allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding (MeshSharding sourceSharding, MeshSharding targetSharding)
 
static MeshSharding targetShardingInMoveLastAxis (MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static ShapedType allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, MeshSharding > moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
TypedValue< ShapedType > reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
SmallVector< Type > shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
 
static LogicalResult spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static std::vector< MeshSharding > getOperandShardings (Operation &op)
 
static std::vector< MeshSharding > getResultShardings (Operation &op)
 
static LogicalResult spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
 

Typedef Documentation

◆ HaloSizePairAttr

Definition at line 29 of file MeshOps.h.

◆ MeshAxesAttr

Definition at line 27 of file MeshOps.h.

◆ MeshAxis

using mlir::mesh::MeshAxis = int16_t

Definition at line 26 of file MeshOps.h.

◆ ShardingArray

Definition at line 25 of file ShardingInterface.h.

◆ ShardingArrayRef

Definition at line 26 of file ShardingInterface.h.

◆ ShardShapeAttr

Definition at line 28 of file MeshOps.h.

◆ UnshardedToShardedValueMap

Definition at line 610 of file Spmdization.cpp.

Enumeration Type Documentation

◆ TraversalOrder

This enum controls the traversal order for the sharding propagation.

Enumerator
Forward 

Forward traversal.

Backward 

Backward traversal.

ForwardBackward 

Forward then backward traversal.

BackwardForward 

Backward then forward traversal.

Definition at line 23 of file Passes.h.

Function Documentation

◆ allGatherResultShapeInUnsplitLastAxis()

static ShapedType mlir::mesh::allGatherResultShapeInUnsplitLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  splitTensorAxis 
)
static

Definition at line 247 of file Spmdization.cpp.

References gatherDimension().

Referenced by unsplitLastAxisInResharding().

◆ allToAllResultShapeInMoveLastAxis()

static ShapedType mlir::mesh::allToAllResultShapeInMoveLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

Definition at line 374 of file Spmdization.cpp.

References gatherDimension(), and shardDimension().

Referenced by moveLastSplitAxisInResharding().

◆ arePartialAxesCompatible()

template<typename SourceAxes , typename TargetAxes >
static bool mlir::mesh::arePartialAxesCompatible ( const SourceAxes &  sourceAxes,
const TargetAxes &  targetAxes 
)
static

Definition at line 42 of file Spmdization.cpp.

Referenced by handlePartialAxesDuringResharding().

◆ collectiveProcessGroupSize() [1/2]

template<typename MeshAxesRange >
int64_t mlir::mesh::collectiveProcessGroupSize ( MeshAxesRange &&  meshAxes,
MeshOp  mesh 
)

Definition at line 169 of file MeshOps.h.

References collectiveProcessGroupSize().

◆ collectiveProcessGroupSize() [2/2]

template<typename MeshAxesRange , typename MeshShapeRange >
int64_t mlir::mesh::collectiveProcessGroupSize ( MeshAxesRange &&  meshAxes,
MeshShapeRange &&  meshShape 
)

◆ createCollectiveProcessGroupSize()

TypedValue< IndexType > mlir::mesh::createCollectiveProcessGroupSize ( MeshOp  mesh,
ArrayRef< MeshAxis >  axes,
ImplicitLocOpBuilder &  builder 
)

◆ createProcessLinearIndex() [1/2]

TypedValue< IndexType > mlir::mesh::createProcessLinearIndex ( StringRef  mesh,
ArrayRef< MeshAxis >  meshAxes,
ImplicitLocOpBuilder &  builder 
)

◆ createProcessLinearIndex() [2/2]

TypedValue< IndexType > mlir::mesh::createProcessLinearIndex ( StringRef  mesh,
ValueRange  processInGroupMultiIndex,
ArrayRef< MeshAxis >  meshAxes,
ImplicitLocOpBuilder &  builder 
)

◆ detectMoveLastSplitAxisInResharding()

static std::optional<std::tuple<int64_t, int64_t, MeshAxis> > mlir::mesh::detectMoveLastSplitAxisInResharding ( MeshSharding  sourceSharding,
MeshSharding  targetSharding 
)
static

◆ detectSplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, MeshAxis> > mlir::mesh::detectSplitLastAxisInResharding ( MeshSharding  sourceSharding,
MeshSharding  targetSharding 
)
static

Definition at line 151 of file Spmdization.cpp.

References mlir::mesh::MeshSharding::getSplitAxes().

Referenced by trySplitLastAxisInResharding().

◆ detectUnsplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, MeshAxis> > mlir::mesh::detectUnsplitLastAxisInResharding ( MeshSharding  sourceSharding,
MeshSharding  targetSharding 
)
static

Definition at line 201 of file Spmdization.cpp.

References mlir::mesh::MeshSharding::getSplitAxes().

Referenced by tryUnsplitLastAxisInResharding().

◆ gatherDimension()

int64_t mlir::mesh::gatherDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

◆ getMesh() [1/2]

template<typename Op >
mesh::MeshOp mlir::mesh::getMesh ( Op  op,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 137 of file MeshOps.h.

References getMesh(), and mlir::Op< ConcreteType, Traits >::getOperation().

◆ getMesh() [2/2]

mesh::MeshOp mlir::mesh::getMesh ( Operation *  op,
FlatSymbolRefAttr  meshSymbol,
SymbolTableCollection &  symbolTableCollection 
)
inline

◆ getMesh< ShardOp >()

template<>
mesh::MeshOp mlir::mesh::getMesh< ShardOp > ( ShardOp  op,
SymbolTableCollection &  symbolTableCollection 
)
inline

Definition at line 142 of file MeshOps.h.

References getMesh().

◆ getMeshAxisAssignmentForLoopIterators()

ShardingArray mlir::mesh::getMeshAxisAssignmentForLoopIterators ( ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< AffineMap >  indexingMaps 
)

Definition at line 642 of file ShardingInterface.cpp.

References updateMeshAxisAssignmentForLoopIterators().

◆ getMeshOrNull()

mesh::MeshOp mlir::mesh::getMeshOrNull ( Operation *  op,
FlatSymbolRefAttr  meshSymbol,
SymbolTableCollection &  symbolTableCollection 
)
inline

◆ getMeshSharding() [1/2]

FailureOr< std::pair< bool, MeshSharding > > mlir::mesh::getMeshSharding ( OpOperand &  opOperand)

◆ getMeshSharding() [2/2]

FailureOr< std::pair< bool, MeshSharding > > mlir::mesh::getMeshSharding ( OpResult  result)

Definition at line 110 of file ShardingInterface.cpp.

References getSharding(), mlir::Value::getUsers(), and mlir::Value::hasOneUse().

Referenced by visitOp().

◆ getMixedAsValues()

SmallVector< Value > mlir::mesh::getMixedAsValues ( OpBuilder  b,
const Location &  loc,
llvm::ArrayRef< int64_t >  statics,
ValueRange  dynamics,
Type  type = Type() 
)

Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.

Definition at line 77 of file MeshOps.cpp.

References mlir::OpBuilder::create(), mlir::Builder::getI64IntegerAttr(), mlir::Builder::getI64Type(), mlir::Builder::getIndexAttr(), and mlir::Builder::getIndexType().

◆ getOperandShardings()

static std::vector<MeshSharding> mlir::mesh::getOperandShardings ( Operation &  op)
static

◆ getReductionMeshAxes()

SmallVector< MeshAxis > mlir::mesh::getReductionMeshAxes ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< MeshAxis >>  meshAxisAssignmentForLoopIterators 
)

◆ getResultShardings()

static std::vector<MeshSharding> mlir::mesh::getResultShardings ( Operation &  op)
static

◆ handlePartialAxesDuringResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::handlePartialAxesDuringResharding ( OpBuilder &  builder,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceShard 
)
static

◆ isAtLeastOneReductionIteratorSharded()

bool mlir::mesh::isAtLeastOneReductionIteratorSharded ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< MeshAxis >>  meshAxisAssignmentForLoopIterators 
)

Definition at line 683 of file ShardingInterface.cpp.

◆ isFullReplication()

bool mlir::mesh::isFullReplication ( MeshSharding  sharding)
inline

◆ isReductionLoop()

bool mlir::mesh::isReductionLoop ( utils::IteratorType  iType)
inline

Definition at line 100 of file MeshOps.h.

Referenced by mlir::mesh::detail::defaultGetShardingOption(), and getSharding().

◆ maybeInsertSourceShardingAnnotation()

void mlir::mesh::maybeInsertSourceShardingAnnotation ( MeshSharding  sharding,
OpOperand &  operand,
OpBuilder &  builder 
)

◆ maybeInsertTargetShardingAnnotation()

void mlir::mesh::maybeInsertTargetShardingAnnotation ( MeshSharding  sharding,
OpResult  result,
OpBuilder &  builder 
)

Definition at line 339 of file MeshOps.cpp.

References mlir::Value::getUses(), and maybeInsertTargetShardingAnnotationImpl().

Referenced by addShardOp().

◆ moveLastSplitAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::moveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis,
MeshAxis  meshAxis 
)
static

◆ populateAllOpLoweringPatterns()

void mlir::mesh::populateAllOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

◆ populateAllReduceEndomorphismSimplificationPatterns()

template<typename AlgebraicOp >
void mlir::mesh::populateAllReduceEndomorphismSimplificationPatterns ( RewritePatternSet &  patterns,
ReductionKind  reduction 
)

Definition at line 40 of file Simplifications.h.

References mlir::patterns.

◆ populateAllSliceOpLoweringPatterns()

void mlir::mesh::populateAllSliceOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 177 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateFoldingPatterns()

void mlir::mesh::populateFoldingPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 115 of file Simplifications.cpp.

References mlir::patterns.

Referenced by populateSimplificationPatterns().

◆ populateProcessMultiIndexOpLoweringPatterns()

void mlir::mesh::populateProcessMultiIndexOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 167 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateSimplificationPatterns()

void mlir::mesh::populateSimplificationPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 24 of file Simplifications.cpp.

References mlir::patterns, and populateFoldingPatterns().

◆ registerAllOpLoweringDialects()

void mlir::mesh::registerAllOpLoweringDialects ( DialectRegistry &  registry)

◆ registerAllSliceOpLoweringDialects()

void mlir::mesh::registerAllSliceOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 183 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerProcessMultiIndexOpLoweringDialects()

void mlir::mesh::registerProcessMultiIndexOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 173 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ removeTrailingEmptySubArray()

template<typename T >
void mlir::mesh::removeTrailingEmptySubArray ( SmallVector< SmallVector< T >> &  array)

Definition at line 106 of file MeshOps.h.

Referenced by mlir::mesh::detail::defaultGetShardingOption(), and getSharding().

◆ reshard() [1/3]

TypedValue<ShapedType> mlir::mesh::reshard ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)

Definition at line 556 of file Spmdization.cpp.

References isFullReplication(), reshardOn1DMesh(), and tryUpdateHaloInResharding().

Referenced by spmdizeOperation().

◆ reshard() [2/3]

TypedValue< ShapedType > mlir::mesh::reshard ( OpBuilder &  builder,
MeshOp  mesh,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue 
)

Definition at line 582 of file Spmdization.cpp.

Referenced by reshard().

◆ reshard() [3/3]

TypedValue< ShapedType > mlir::mesh::reshard ( OpBuilder &  builder,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 594 of file Spmdization.cpp.

References getMesh(), and reshard().

◆ reshardingRegisterDependentDialects()

void mlir::mesh::reshardingRegisterDependentDialects ( DialectRegistry &  registry)

Definition at line 603 of file Spmdization.cpp.

References mlir::DialectRegistry::insert().

◆ reshardOn1DMesh()

static TypedValue<ShapedType> mlir::mesh::reshardOn1DMesh ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)
static

◆ shardDimension()

int64_t mlir::mesh::shardDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

Definition at line 175 of file MeshOps.h.

Referenced by allToAllResultShapeInMoveLastAxis().

◆ shardedBlockArgumentTypes()

SmallVector<Type> mlir::mesh::shardedBlockArgumentTypes ( Block &  block,
SymbolTableCollection &  symbolTableCollection 
)

◆ shardShapedType()

ShapedType mlir::mesh::shardShapedType ( ShapedType  shape,
MeshOp  mesh,
MeshSharding  sharding 
)

◆ shardType()

Type mlir::mesh::shardType ( Type  type,
MeshOp  mesh,
MeshSharding  sharding 
)

Definition at line 292 of file MeshOps.cpp.

References shardShapedType().

Referenced by spmdizeTriviallyShardableOperation().

◆ splitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::splitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshSharding  sourceSharding,
TypedValue< ShapedType >  sourceShard,
MeshOp  mesh,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static

◆ spmdizeBlock()

static LogicalResult mlir::mesh::spmdizeBlock ( Block &  block,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeFullyReplicatedOperation()

void mlir::mesh::spmdizeFullyReplicatedOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

◆ spmdizeFuncOp()

static LogicalResult mlir::mesh::spmdizeFuncOp ( FunctionOpInterface  op,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection 
)
static

Definition at line 807 of file Spmdization.cpp.

References mlir::get(), mlir::Operation::getOperandTypes(), and spmdizeBlock().

◆ spmdizeOperation() [1/3]

static LogicalResult mlir::mesh::spmdizeOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeOperation() [2/3]

static LogicalResult mlir::mesh::spmdizeOperation ( Operation &  op,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeOperation() [3/3]

static LogicalResult mlir::mesh::spmdizeOperation ( ShardOp  shardOp,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ spmdizeTriviallyShardableOperation()

void mlir::mesh::spmdizeTriviallyShardableOperation ( Operation &  op,
ArrayRef< Value >  spmdizedOperands,
ArrayRef< MeshSharding >  operandShardings,
ArrayRef< MeshSharding >  resultShardings,
IRMapping &  spmdizationMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

◆ targetShardingInMoveLastAxis()

static MeshSharding mlir::mesh::targetShardingInMoveLastAxis ( MLIRContext *  ctx,
MeshSharding  sourceSharding,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

◆ targetShardingInSplitLastAxis()

static MeshSharding mlir::mesh::targetShardingInSplitLastAxis ( MLIRContext *  ctx,
MeshSharding  sourceSharding,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static

◆ targetShardingInUnsplitLastAxis()

static MeshSharding mlir::mesh::targetShardingInUnsplitLastAxis ( MLIRContext *  ctx,
MeshSharding  sourceSharding,
int64_t  splitTensorAxis 
)
static

◆ tryMoveLastSplitAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::tryMoveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

◆ trySplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::trySplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 183 of file Spmdization.cpp.

References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().

Referenced by reshardOn1DMesh().

◆ tryUnsplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::tryUnsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 281 of file Spmdization.cpp.

References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().

Referenced by reshardOn1DMesh().

◆ tryUpdateHaloInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding> > mlir::mesh::tryUpdateHaloInResharding ( ImplicitLocOpBuilder &  builder,
MeshOp  mesh,
MeshSharding  sourceSharding,
MeshSharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

◆ unsplitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, MeshSharding> mlir::mesh::unsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
MeshSharding  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
MeshOp  mesh,
int64_t  splitTensorAxis,
MeshAxis  splitMeshAxis 
)
static