MLIR  22.0.0git
Namespaces | Classes | Typedefs | Enumerations | Functions
mlir::shard Namespace Reference

Namespaces

 detail
 

Classes

struct  ShardingOption
 
struct  IndependentParallelIteratorDomainShardingInterface
 
struct  ElementwiseShardingInterface
 
class  Sharding
 
struct  OpRewritePatternWithSymbolTableCollection
 

Typedefs

using ShardingArray = SmallVector< SmallVector< GridAxis > >
 
using ShardingArrayRef = ArrayRef< SmallVector< GridAxis > >
 
using GridAxis = int16_t
 
using GridAxesAttr = DenseI16ArrayAttr
 
using ShardShapeAttr = DenseI64ArrayAttr
 
using HaloSizePairAttr = DenseI64ArrayAttr
 
using UnshardedToShardedValueMap = DenseMap< Value, Value >
 

Enumerations

enum class  TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward }
 This enum controls the traversal order for the sharding propagation. More...
 

Functions

FailureOr< std::pair< bool, Sharding > > getSharding (OpResult result)
 
FailureOr< std::pair< bool, Sharding > > getSharding (OpOperand &opOperand)
 
void partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
ShardingArray getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
 
bool isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
 
SmallVector< GridAxis > getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
 
void partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
 
bool isReductionLoop (utils::IteratorType iType)
 
template<typename T >
void removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array)
 
bool isFullReplication (Sharding sharding)
 
shard::GridOp getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
 
shard::GridOp getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
 
template<typename Op >
shard::GridOp getGrid (Op op, SymbolTableCollection &symbolTableCollection)
 
template<>
shard::GridOp getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection)
 
template<typename GridAxesRange , typename GridShapeRange >
int64_t collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
 
template<typename GridAxesRange >
int64_t collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid)
 
int64_t shardDimension (int64_t dimSize, int64_t shardCount)
 
int64_t gatherDimension (int64_t dimSize, int64_t shardCount)
 
ShapedType shardShapedType (ShapedType shape, GridOp grid, Sharding sharding)
 
Type shardType (Type type, GridOp grid, Sharding sharding)
 
void maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder)
 
void maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder)
 
SmallVector< Value > getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
 Converts a vector of OpFoldResults (ints) into vector of Values of the provided type. More...
 
TypedValue< ShapedType > reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
 
TypedValue< ShapedType > reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
 
void reshardingRegisterDependentDialects (DialectRegistry &registry)
 
template<typename AlgebraicOp >
void populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction)
 
void populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry)
 
void populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllSliceOpLoweringDialects (DialectRegistry &registry)
 
void populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
 
void registerAllOpLoweringDialects (DialectRegistry &registry)
 
TypedValue< IndexType > createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder)
 
TypedValue< IndexType > createProcessLinearIndex (StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
 
TypedValue< IndexType > createProcessLinearIndex (StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
 
template<typename SourceAxes , typename TargetAxes >
static bool arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
 
static Sharding targetShardingInSplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
 
static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
 
static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, GridAxis > > detectUnsplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
 
static Sharding targetShardingInUnsplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis)
 
static ShapedType allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
 
static Sharding targetShardingInMoveLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static ShapedType allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
 
static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > reshardOn1DGrid (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
static TypedValue< ShapedType > reshard (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
 
static SmallVector< Type > shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
 
static LogicalResult partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static std::vector< Sharding > getOperandShardings (Operation &op)
 
static std::vector< Sharding > getResultShardings (Operation &op)
 
static LogicalResult partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
 
static LogicalResult partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
 

Typedef Documentation

◆ GridAxesAttr

Definition at line 27 of file ShardOps.h.

◆ GridAxis

using mlir::shard::GridAxis = int16_t

Definition at line 26 of file ShardOps.h.

◆ HaloSizePairAttr

Definition at line 29 of file ShardOps.h.

◆ ShardingArray

Definition at line 25 of file ShardingInterface.h.

◆ ShardingArrayRef

Definition at line 26 of file ShardingInterface.h.

◆ ShardShapeAttr

Definition at line 28 of file ShardOps.h.

◆ UnshardedToShardedValueMap

Definition at line 533 of file Partition.cpp.

Enumeration Type Documentation

◆ TraversalOrder

This enum controls the traversal order for the sharding propagation.

Enumerator
Forward 

Forward traversal.

Backward 

Backward traversal.

ForwardBackward 

Forward then backward traversal.

BackwardForward 

Backward then forward traversal.

Definition at line 23 of file Passes.h.

Function Documentation

◆ allGatherResultShapeInUnsplitLastAxis()

static ShapedType mlir::shard::allGatherResultShapeInUnsplitLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  splitTensorAxis 
)
static

Definition at line 180 of file Partition.cpp.

References gatherDimension().

Referenced by unsplitLastAxisInResharding().

◆ allToAllResultShapeInMoveLastAxis()

static ShapedType mlir::shard::allToAllResultShapeInMoveLastAxis ( ShapedType  sourceShape,
int64_t  splitCount,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

Definition at line 304 of file Partition.cpp.

References gatherDimension(), and shardDimension().

Referenced by moveLastSplitAxisInResharding().

◆ arePartialAxesCompatible()

template<typename SourceAxes , typename TargetAxes >
static bool mlir::shard::arePartialAxesCompatible ( const SourceAxes &  sourceAxes,
const TargetAxes &  targetAxes 
)
static

Definition at line 39 of file Partition.cpp.

◆ collectiveProcessGroupSize() [1/2]

template<typename GridAxesRange >
int64_t mlir::shard::collectiveProcessGroupSize ( GridAxesRange &&  gridAxes,
GridOp  grid 
)

Definition at line 162 of file ShardOps.h.

References collectiveProcessGroupSize().

◆ collectiveProcessGroupSize() [2/2]

template<typename GridAxesRange , typename GridShapeRange >
int64_t mlir::shard::collectiveProcessGroupSize ( GridAxesRange &&  gridAxes,
GridShapeRange &&  gridShape 
)

◆ createCollectiveProcessGroupSize()

TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize ( GridOp  grid,
ArrayRef< GridAxis >  axes,
ImplicitLocOpBuilder &  builder 
)

◆ createProcessLinearIndex() [1/2]

TypedValue< IndexType > mlir::shard::createProcessLinearIndex ( StringRef  grid,
ArrayRef< GridAxis >  gridAxes,
ImplicitLocOpBuilder &  builder 
)

◆ createProcessLinearIndex() [2/2]

TypedValue< IndexType > mlir::shard::createProcessLinearIndex ( StringRef  grid,
ValueRange  processInGroupMultiIndex,
ArrayRef< GridAxis >  gridAxes,
ImplicitLocOpBuilder &  builder 
)

◆ detectMoveLastSplitAxisInResharding()

static std::optional<std::tuple<int64_t, int64_t, GridAxis> > mlir::shard::detectMoveLastSplitAxisInResharding ( Sharding  sourceSharding,
Sharding  targetSharding 
)
static

Definition at line 235 of file Partition.cpp.

References mlir::shard::Sharding::getSplitAxes().

Referenced by tryMoveLastSplitAxisInResharding().

◆ detectSplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, GridAxis> > mlir::shard::detectSplitLastAxisInResharding ( Sharding  sourceSharding,
Sharding  targetSharding 
)
static

Definition at line 87 of file Partition.cpp.

References mlir::shard::Sharding::getSplitAxes().

Referenced by trySplitLastAxisInResharding().

◆ detectUnsplitLastAxisInResharding()

static std::optional<std::tuple<int64_t, GridAxis> > mlir::shard::detectUnsplitLastAxisInResharding ( Sharding  sourceSharding,
Sharding  targetSharding 
)
static

Definition at line 136 of file Partition.cpp.

References mlir::shard::Sharding::getSplitAxes().

Referenced by tryUnsplitLastAxisInResharding().

◆ gatherDimension()

int64_t mlir::shard::gatherDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

◆ getGrid() [1/2]

template<typename Op >
shard::GridOp mlir::shard::getGrid ( Op  op,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 130 of file ShardOps.h.

References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().

◆ getGrid() [2/2]

shard::GridOp mlir::shard::getGrid ( Operation *  op,
FlatSymbolRefAttr  gridSymbol,
SymbolTableCollection &  symbolTableCollection 
)
inline

◆ getGrid< ShardOp >()

template<>
shard::GridOp mlir::shard::getGrid< ShardOp > ( ShardOp  op,
SymbolTableCollection &  symbolTableCollection 
)
inline

Definition at line 135 of file ShardOps.h.

References getGrid().

◆ getGridAxisAssignmentForLoopIterators()

ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators ( ArrayRef< Sharding >  operandShardings,
ArrayRef< Sharding >  resultShardings,
ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< AffineMap >  indexingMaps 
)

Definition at line 572 of file ShardingInterface.cpp.

References updateGridAxisAssignmentForLoopIterators().

◆ getGridOrNull()

shard::GridOp mlir::shard::getGridOrNull ( Operation *  op,
FlatSymbolRefAttr  gridSymbol,
SymbolTableCollection &  symbolTableCollection 
)
inline

◆ getMixedAsValues()

SmallVector< Value > mlir::shard::getMixedAsValues ( OpBuilder  b,
const Location &  loc,
llvm::ArrayRef< int64_t >  statics,
ValueRange  dynamics,
Type  type = Type() 
)

Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.

Definition at line 77 of file ShardOps.cpp.

References mlir::Builder::getI64IntegerAttr(), mlir::Builder::getI64Type(), mlir::Builder::getIndexAttr(), and mlir::Builder::getIndexType().

◆ getOperandShardings()

static std::vector<Sharding> mlir::shard::getOperandShardings ( Operation &  op)
static

◆ getReductionGridAxes()

SmallVector< GridAxis > mlir::shard::getReductionGridAxes ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< GridAxis >>  gridAxisAssignmentForLoopIterators 
)

◆ getResultShardings()

static std::vector<Sharding> mlir::shard::getResultShardings ( Operation &  op)
static

◆ getSharding() [1/2]

FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding ( OpOperand &  opOperand)

◆ getSharding() [2/2]

FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding ( OpResult  result)

◆ isAtLeastOneReductionIteratorSharded()

bool mlir::shard::isAtLeastOneReductionIteratorSharded ( ArrayRef< utils::IteratorType >  loopIteratorTypes,
ArrayRef< SmallVector< GridAxis >>  gridAxisAssignmentForLoopIterators 
)

Definition at line 612 of file ShardingInterface.cpp.

◆ isFullReplication()

bool mlir::shard::isFullReplication ( Sharding  sharding)
inline

◆ isReductionLoop()

bool mlir::shard::isReductionLoop ( utils::IteratorType  iType)
inline

Definition at line 94 of file ShardOps.h.

◆ maybeInsertSourceShardingAnnotation()

void mlir::shard::maybeInsertSourceShardingAnnotation ( Sharding  sharding,
OpOperand &  operand,
OpBuilder &  builder 
)

◆ maybeInsertTargetShardingAnnotation()

void mlir::shard::maybeInsertTargetShardingAnnotation ( Sharding  sharding,
OpResult  result,
OpBuilder &  builder 
)

Definition at line 338 of file ShardOps.cpp.

References mlir::Value::getUses(), and maybeInsertTargetShardingAnnotationImpl().

Referenced by addShardOp().

◆ moveLastSplitAxisInResharding()

static std::tuple<TypedValue<ShapedType>, Sharding> mlir::shard::moveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
GridOp  grid,
Sharding  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis,
GridAxis  gridAxis 
)
static

◆ partitionBlock()

static LogicalResult mlir::shard::partitionBlock ( Block &  block,
IRMapping &  partitionMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ partitionFullyReplicatedOperation()

void mlir::shard::partitionFullyReplicatedOperation ( Operation &  op,
ArrayRef< Value >  partitionedOperands,
ArrayRef< Sharding >  operandShardings,
ArrayRef< Sharding >  resultShardings,
IRMapping &  partitionMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

◆ partitionFuncOp()

static LogicalResult mlir::shard::partitionFuncOp ( FunctionOpInterface  op,
IRMapping &  partitionMap,
SymbolTableCollection &  symbolTableCollection 
)
static

◆ partitionOperation() [1/3]

static LogicalResult mlir::shard::partitionOperation ( Operation &  op,
ArrayRef< Value >  partitionedOperands,
ArrayRef< Sharding >  operandShardings,
ArrayRef< Sharding >  resultShardings,
IRMapping &  partitionMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ partitionOperation() [2/3]

static LogicalResult mlir::shard::partitionOperation ( Operation &  op,
IRMapping &  partitionMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ partitionOperation() [3/3]

static LogicalResult mlir::shard::partitionOperation ( ShardOp  shardOp,
IRMapping &  partitionMap,
SymbolTableCollection &  symbolTableCollection,
OpBuilder &  builder 
)
static

◆ partitionTriviallyShardableOperation()

void mlir::shard::partitionTriviallyShardableOperation ( Operation &  op,
ArrayRef< Value >  partitionedOperands,
ArrayRef< Sharding >  operandShardings,
ArrayRef< Sharding >  resultShardings,
IRMapping &  partitionMap,
SymbolTableCollection &  symbolTable,
OpBuilder &  builder 
)

◆ populateAllOpLoweringPatterns()

void mlir::shard::populateAllOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

◆ populateAllReduceEndomorphismSimplificationPatterns()

template<typename AlgebraicOp >
void mlir::shard::populateAllReduceEndomorphismSimplificationPatterns ( RewritePatternSet &  patterns,
ReductionKind  reduction 
)

Definition at line 40 of file Simplifications.h.

References mlir::patterns.

◆ populateAllSliceOpLoweringPatterns()

void mlir::shard::populateAllSliceOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 178 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateFoldingPatterns()

void mlir::shard::populateFoldingPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 114 of file Simplifications.cpp.

References mlir::patterns.

Referenced by populateSimplificationPatterns().

◆ populateProcessMultiIndexOpLoweringPatterns()

void mlir::shard::populateProcessMultiIndexOpLoweringPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 168 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateSimplificationPatterns()

void mlir::shard::populateSimplificationPatterns ( RewritePatternSet &  patterns,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 23 of file Simplifications.cpp.

References mlir::patterns, and populateFoldingPatterns().

◆ registerAllOpLoweringDialects()

void mlir::shard::registerAllOpLoweringDialects ( DialectRegistry &  registry)

◆ registerAllSliceOpLoweringDialects()

void mlir::shard::registerAllSliceOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 184 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerProcessMultiIndexOpLoweringDialects()

void mlir::shard::registerProcessMultiIndexOpLoweringDialects ( DialectRegistry &  registry)

Definition at line 174 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ removeTrailingEmptySubArray()

template<typename T >
void mlir::shard::removeTrailingEmptySubArray ( SmallVector< SmallVector< T >> &  array)

Definition at line 100 of file ShardOps.h.

Referenced by mlir::shard::detail::defaultGetShardingOption(), and getSharding().

◆ reshard() [1/3]

static TypedValue<ShapedType> mlir::shard::reshard ( ImplicitLocOpBuilder &  builder,
GridOp  grid,
Sharding  sourceSharding,
Sharding  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 481 of file Partition.cpp.

References isFullReplication(), reshardOn1DGrid(), and tryUpdateHaloInResharding().

Referenced by partitionOperation().

◆ reshard() [2/3]

TypedValue< ShapedType > mlir::shard::reshard ( OpBuilder &  builder,
GridOp  grid,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue 
)

Definition at line 505 of file Partition.cpp.

Referenced by reshard().

◆ reshard() [3/3]

TypedValue< ShapedType > mlir::shard::reshard ( OpBuilder &  builder,
ShardOp  source,
ShardOp  target,
TypedValue< ShapedType >  sourceShardValue,
SymbolTableCollection &  symbolTableCollection 
)

Definition at line 517 of file Partition.cpp.

References getGrid(), and reshard().

◆ reshardingRegisterDependentDialects()

void mlir::shard::reshardingRegisterDependentDialects ( DialectRegistry &  registry)

Definition at line 526 of file Partition.cpp.

References mlir::DialectRegistry::insert().

◆ reshardOn1DGrid()

static TypedValue<ShapedType> mlir::shard::reshardOn1DGrid ( ImplicitLocOpBuilder &  builder,
GridOp  grid,
Sharding  sourceSharding,
Sharding  targetSharding,
TypedValue< ShapedType >  sourceUnshardedValue,
TypedValue< ShapedType >  sourceShard 
)
static

◆ shardDimension()

int64_t mlir::shard::shardDimension ( int64_t  dimSize,
int64_t  shardCount 
)
inline

Definition at line 168 of file ShardOps.h.

Referenced by allToAllResultShapeInMoveLastAxis().

◆ shardedBlockArgumentTypes()

static SmallVector<Type> mlir::shard::shardedBlockArgumentTypes ( Block &  block,
SymbolTableCollection &  symbolTableCollection 
)
static

◆ shardShapedType()

ShapedType mlir::shard::shardShapedType ( ShapedType  shape,
GridOp  grid,
Sharding  sharding 
)

◆ shardType()

Type mlir::shard::shardType ( Type  type,
GridOp  grid,
Sharding  sharding 
)

Definition at line 291 of file ShardOps.cpp.

References shardShapedType().

Referenced by partitionTriviallyShardableOperation().

◆ splitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, Sharding> mlir::shard::splitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
Sharding  sourceSharding,
TypedValue< ShapedType >  sourceShard,
GridOp  grid,
int64_t  splitTensorAxis,
GridAxis  splitGridAxis 
)
static

◆ targetShardingInMoveLastAxis()

static Sharding mlir::shard::targetShardingInMoveLastAxis ( MLIRContext *  ctx,
Sharding  sourceSharding,
int64_t  sourceTensorAxis,
int64_t  targetTensorAxis 
)
static

◆ targetShardingInSplitLastAxis()

static Sharding mlir::shard::targetShardingInSplitLastAxis ( MLIRContext *  ctx,
Sharding  sourceSharding,
int64_t  splitTensorAxis,
GridAxis  splitGridAxis 
)
static

◆ targetShardingInUnsplitLastAxis()

static Sharding mlir::shard::targetShardingInUnsplitLastAxis ( MLIRContext *  ctx,
Sharding  sourceSharding,
int64_t  splitTensorAxis 
)
static

◆ tryMoveLastSplitAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding> > mlir::shard::tryMoveLastSplitAxisInResharding ( ImplicitLocOpBuilder &  builder,
GridOp  grid,
Sharding  sourceSharding,
Sharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

◆ trySplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding> > mlir::shard::trySplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
GridOp  grid,
Sharding  sourceSharding,
Sharding  targetSharding,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 119 of file Partition.cpp.

References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().

Referenced by reshardOn1DGrid().

◆ tryUnsplitLastAxisInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding> > mlir::shard::tryUnsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
GridOp  grid,
Sharding  sourceSharding,
Sharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

Definition at line 214 of file Partition.cpp.

References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().

Referenced by reshardOn1DGrid().

◆ tryUpdateHaloInResharding()

static std::optional<std::tuple<TypedValue<ShapedType>, Sharding> > mlir::shard::tryUpdateHaloInResharding ( ImplicitLocOpBuilder &  builder,
GridOp  grid,
Sharding  sourceSharding,
Sharding  targetSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard 
)
static

◆ unsplitLastAxisInResharding()

static std::tuple<TypedValue<ShapedType>, Sharding> mlir::shard::unsplitLastAxisInResharding ( ImplicitLocOpBuilder &  builder,
Sharding  sourceSharding,
ShapedType  sourceUnshardedShape,
TypedValue< ShapedType >  sourceShard,
GridOp  grid,
int64_t  splitTensorAxis,
GridAxis  splitGridAxis 
)
static