MLIR 22.0.0git
mlir::shard Namespace Reference

Namespaces

namespace  detail
namespace  impl

Classes

struct  ElementwiseShardingInterface
struct  IndependentParallelIteratorDomainShardingInterface
struct  OpRewritePatternWithSymbolTableCollection
class  Sharding
struct  ShardingOption
struct  ShardingPropagationOptions

Typedefs

using ShardingArray = SmallVector<SmallVector<GridAxis>>
using ShardingArrayRef = ArrayRef<SmallVector<GridAxis>>
using GridAxis = int16_t
using GridAxesAttr = DenseI16ArrayAttr
using ShardShapeAttr = DenseI64ArrayAttr
using HaloSizePairAttr = DenseI64ArrayAttr
using UnshardedToShardedValueMap = DenseMap<Value, Value>

Enumerations

enum class  TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward }
 This enum controls the traversal order for the sharding propagation. More...

Functions

FailureOr< std::pair< bool, Sharding > > getSharding (OpResult result)
FailureOr< std::pair< bool, Sharding > > getSharding (OpOperand &opOperand)
void partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
ShardingArray getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
SmallVector< GridAxis > getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
void partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
bool isReductionLoop (utils::IteratorType iType)
template<typename T>
void removeTrailingEmptySubArray (SmallVector< SmallVector< T > > &array)
bool isFullReplication (Sharding sharding)
shard::GridOp getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
shard::GridOp getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
template<typename Op>
shard::GridOp getGrid (Op op, SymbolTableCollection &symbolTableCollection)
template<>
shard::GridOp getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection)
template<typename GridAxesRange, typename GridShapeRange>
int64_t collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
template<typename GridAxesRange>
int64_t collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid)
int64_t shardDimension (int64_t dimSize, int64_t shardCount)
int64_t gatherDimension (int64_t dimSize, int64_t shardCount)
ShapedType shardShapedType (ShapedType shape, GridOp grid, Sharding sharding)
Type shardType (Type type, GridOp grid, Sharding sharding)
void maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder)
void maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder)
SmallVector< Value > getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
 Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
TypedValue< ShapedType > reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
TypedValue< ShapedType > reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
void reshardingRegisterDependentDialects (DialectRegistry &registry)
std::unique_ptr<::mlir::Pass > createPartition ()
std::unique_ptr<::mlir::Pass > createShardingPropagation ()
std::unique_ptr<::mlir::Pass > createShardingPropagation (ShardingPropagationOptions options)
void registerPartition ()
void registerPartitionPass ()
void registerShardingPropagation ()
void registerShardingPropagationPass ()
void registerShardPasses ()
template<typename AlgebraicOp>
void populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction)
void populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry)
void populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerAllSliceOpLoweringDialects (DialectRegistry &registry)
void populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerAllOpLoweringDialects (DialectRegistry &registry)
TypedValue< IndexType > createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder)
TypedValue< IndexType > createProcessLinearIndex (StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
TypedValue< IndexType > createProcessLinearIndex (StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder)
template<typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
static Sharding targetShardingInSplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< int64_t, GridAxis > > detectUnsplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
static Sharding targetShardingInUnsplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis)
static ShapedType allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding (Sharding sourceSharding, Sharding targetSharding)
static Sharding targetShardingInMoveLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static ShapedType allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static TypedValue< ShapedType > reshardOn1DGrid (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static TypedValue< ShapedType > reshard (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static SmallVector< Type > shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings (Operation &op)
static std::vector< Sharding > getResultShardings (Operation &op)
static LogicalResult partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)

Typedef Documentation

◆ GridAxesAttr

Definition at line 27 of file ShardOps.h.

◆ GridAxis

using mlir::shard::GridAxis = int16_t

Definition at line 26 of file ShardOps.h.

◆ HaloSizePairAttr

Definition at line 29 of file ShardOps.h.

◆ ShardingArray

◆ ShardingArrayRef

◆ ShardShapeAttr

Definition at line 28 of file ShardOps.h.

◆ UnshardedToShardedValueMap

Enumeration Type Documentation

◆ TraversalOrder

enum class mlir::shard::TraversalOrder
strong

This enum controls the traversal order for the sharding propagation.

Enumerator
Forward 

Forward traversal.

Backward 

Backward traversal.

ForwardBackward 

Forward then backward traversal.

BackwardForward 

Backward then forward traversal.

Definition at line 23 of file Passes.h.

Function Documentation

◆ allGatherResultShapeInUnsplitLastAxis()

ShapedType mlir::shard::allGatherResultShapeInUnsplitLastAxis ( ShapedType sourceShape,
int64_t splitCount,
int64_t splitTensorAxis )
static

Definition at line 180 of file Partition.cpp.

References gatherDimension().

Referenced by unsplitLastAxisInResharding().

◆ allToAllResultShapeInMoveLastAxis()

ShapedType mlir::shard::allToAllResultShapeInMoveLastAxis ( ShapedType sourceShape,
int64_t splitCount,
int64_t sourceTensorAxis,
int64_t targetTensorAxis )
static

Definition at line 303 of file Partition.cpp.

References gatherDimension(), and shardDimension().

Referenced by moveLastSplitAxisInResharding().

◆ arePartialAxesCompatible()

template<typename SourceAxes, typename TargetAxes>
bool mlir::shard::arePartialAxesCompatible ( const SourceAxes & sourceAxes,
const TargetAxes & targetAxes )
static

Definition at line 39 of file Partition.cpp.

◆ collectiveProcessGroupSize() [1/2]

template<typename GridAxesRange>
int64_t mlir::shard::collectiveProcessGroupSize ( GridAxesRange && gridAxes,
GridOp grid )

Definition at line 162 of file ShardOps.h.

References collectiveProcessGroupSize().

◆ collectiveProcessGroupSize() [2/2]

template<typename GridAxesRange, typename GridShapeRange>
int64_t mlir::shard::collectiveProcessGroupSize ( GridAxesRange && gridAxes,
GridShapeRange && gridShape )

◆ createCollectiveProcessGroupSize()

TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize ( GridOp grid,
ArrayRef< GridAxis > axes,
ImplicitLocOpBuilder & builder )

◆ createPartition()

std::unique_ptr<::mlir::Pass > mlir::shard::createPartition ( )

We declare an explicit private instantiation because Pass classes should only be visible by the current library.

Definition at line 80 of file Partition.cpp.

◆ createProcessLinearIndex() [1/2]

TypedValue< IndexType > mlir::shard::createProcessLinearIndex ( StringRef grid,
ArrayRef< GridAxis > gridAxes,
ImplicitLocOpBuilder & builder )

◆ createProcessLinearIndex() [2/2]

TypedValue< IndexType > mlir::shard::createProcessLinearIndex ( StringRef grid,
ValueRange processInGroupMultiIndex,
ArrayRef< GridAxis > gridAxes,
ImplicitLocOpBuilder & builder )

Definition at line 212 of file Transforms.cpp.

References mlir::arith::ConstantIndexOp::create().

◆ createShardingPropagation() [1/2]

std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation ( )

Definition at line 180 of file ShardingPropagation.cpp.

◆ createShardingPropagation() [2/2]

std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation ( ShardingPropagationOptions options)

◆ detectMoveLastSplitAxisInResharding()

std::optional< std::tuple< int64_t, int64_t, GridAxis > > mlir::shard::detectMoveLastSplitAxisInResharding ( Sharding sourceSharding,
Sharding targetSharding )
static

Definition at line 234 of file Partition.cpp.

References mlir::shard::Sharding::getSplitAxes().

Referenced by tryMoveLastSplitAxisInResharding().

◆ detectSplitLastAxisInResharding()

std::optional< std::tuple< int64_t, GridAxis > > mlir::shard::detectSplitLastAxisInResharding ( Sharding sourceSharding,
Sharding targetSharding )
static

Definition at line 87 of file Partition.cpp.

References mlir::shard::Sharding::getSplitAxes().

Referenced by trySplitLastAxisInResharding().

◆ detectUnsplitLastAxisInResharding()

std::optional< std::tuple< int64_t, GridAxis > > mlir::shard::detectUnsplitLastAxisInResharding ( Sharding sourceSharding,
Sharding targetSharding )
static

Definition at line 136 of file Partition.cpp.

References mlir::shard::Sharding::getSplitAxes().

Referenced by tryUnsplitLastAxisInResharding().

◆ gatherDimension()

int64_t mlir::shard::gatherDimension ( int64_t dimSize,
int64_t shardCount )
inline

◆ getGrid() [1/2]

template<typename Op>
shard::GridOp mlir::shard::getGrid ( Op op,
SymbolTableCollection & symbolTableCollection )

Definition at line 130 of file ShardOps.h.

References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().

◆ getGrid() [2/2]

shard::GridOp mlir::shard::getGrid ( Operation * op,
FlatSymbolRefAttr gridSymbol,
SymbolTableCollection & symbolTableCollection )
inline

◆ getGrid< ShardOp >()

template<>
shard::GridOp mlir::shard::getGrid< ShardOp > ( ShardOp op,
SymbolTableCollection & symbolTableCollection )
inline

Definition at line 135 of file ShardOps.h.

References getGrid().

◆ getGridAxisAssignmentForLoopIterators()

ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators ( ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
ArrayRef< utils::IteratorType > loopIteratorTypes,
ArrayRef< AffineMap > indexingMaps )

Definition at line 572 of file ShardingInterface.cpp.

References updateGridAxisAssignmentForLoopIterators().

◆ getGridOrNull()

shard::GridOp mlir::shard::getGridOrNull ( Operation * op,
FlatSymbolRefAttr gridSymbol,
SymbolTableCollection & symbolTableCollection )
inline

◆ getMixedAsValues()

SmallVector< Value > mlir::shard::getMixedAsValues ( OpBuilder b,
const Location & loc,
llvm::ArrayRef< int64_t > statics,
ValueRange dynamics,
Type type = Type() )

Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.

Definition at line 77 of file ShardOps.cpp.

References b.

◆ getOperandShardings()

std::vector< Sharding > mlir::shard::getOperandShardings ( Operation & op)
static

◆ getReductionGridAxes()

SmallVector< GridAxis > mlir::shard::getReductionGridAxes ( ArrayRef< utils::IteratorType > loopIteratorTypes,
ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators )

◆ getResultShardings()

std::vector< Sharding > mlir::shard::getResultShardings ( Operation & op)
static

◆ getSharding() [1/2]

FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding ( OpOperand & opOperand)

◆ getSharding() [2/2]

FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding ( OpResult result)

◆ isAtLeastOneReductionIteratorSharded()

bool mlir::shard::isAtLeastOneReductionIteratorSharded ( ArrayRef< utils::IteratorType > loopIteratorTypes,
ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators )

Definition at line 612 of file ShardingInterface.cpp.

◆ isFullReplication()

bool mlir::shard::isFullReplication ( Sharding sharding)
inline

◆ isReductionLoop()

bool mlir::shard::isReductionLoop ( utils::IteratorType iType)
inline

Definition at line 94 of file ShardOps.h.

◆ maybeInsertSourceShardingAnnotation()

◆ maybeInsertTargetShardingAnnotation()

void mlir::shard::maybeInsertTargetShardingAnnotation ( Sharding sharding,
OpResult result,
OpBuilder & builder )

Definition at line 338 of file ShardOps.cpp.

References maybeInsertTargetShardingAnnotationImpl(), and result.

Referenced by addShardOp().

◆ moveLastSplitAxisInResharding()

std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::moveLastSplitAxisInResharding ( ImplicitLocOpBuilder & builder,
GridOp grid,
Sharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue< ShapedType > sourceShard,
int64_t sourceTensorAxis,
int64_t targetTensorAxis,
GridAxis gridAxis )
static

◆ partitionBlock()

◆ partitionFullyReplicatedOperation()

void mlir::shard::partitionFullyReplicatedOperation ( Operation & op,
ArrayRef< Value > partitionedOperands,
ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
IRMapping & partitionMap,
SymbolTableCollection & symbolTable,
OpBuilder & builder )

◆ partitionFuncOp()

LogicalResult mlir::shard::partitionFuncOp ( FunctionOpInterface op,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection )
static

◆ partitionOperation() [1/3]

LogicalResult mlir::shard::partitionOperation ( Operation & op,
ArrayRef< Value > partitionedOperands,
ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection,
OpBuilder & builder )
static

◆ partitionOperation() [2/3]

LogicalResult mlir::shard::partitionOperation ( Operation & op,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection,
OpBuilder & builder )
static

◆ partitionOperation() [3/3]

LogicalResult mlir::shard::partitionOperation ( ShardOp shardOp,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection,
OpBuilder & builder )
static

◆ partitionTriviallyShardableOperation()

void mlir::shard::partitionTriviallyShardableOperation ( Operation & op,
ArrayRef< Value > partitionedOperands,
ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
IRMapping & partitionMap,
SymbolTableCollection & symbolTable,
OpBuilder & builder )

◆ populateAllOpLoweringPatterns()

void mlir::shard::populateAllOpLoweringPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

◆ populateAllReduceEndomorphismSimplificationPatterns()

template<typename AlgebraicOp>
void mlir::shard::populateAllReduceEndomorphismSimplificationPatterns ( RewritePatternSet & patterns,
ReductionKind reduction )

Definition at line 40 of file Simplifications.h.

References mlir::patterns.

Referenced by populateSimplificationPatterns().

◆ populateAllSliceOpLoweringPatterns()

void mlir::shard::populateAllSliceOpLoweringPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

Definition at line 178 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateFoldingPatterns()

void mlir::shard::populateFoldingPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

Definition at line 114 of file Simplifications.cpp.

References mlir::patterns.

Referenced by populateSimplificationPatterns().

◆ populateProcessMultiIndexOpLoweringPatterns()

void mlir::shard::populateProcessMultiIndexOpLoweringPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

Definition at line 168 of file Transforms.cpp.

References mlir::patterns.

Referenced by populateAllOpLoweringPatterns().

◆ populateSimplificationPatterns()

void mlir::shard::populateSimplificationPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

◆ registerAllOpLoweringDialects()

void mlir::shard::registerAllOpLoweringDialects ( DialectRegistry & registry)

◆ registerAllSliceOpLoweringDialects()

void mlir::shard::registerAllSliceOpLoweringDialects ( DialectRegistry & registry)

Definition at line 184 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerPartition()

void mlir::shard::registerPartition ( )
inline

Definition at line 200 of file Passes.h.

◆ registerPartitionPass()

void mlir::shard::registerPartitionPass ( )
inline

Definition at line 207 of file Passes.h.

◆ registerProcessMultiIndexOpLoweringDialects()

void mlir::shard::registerProcessMultiIndexOpLoweringDialects ( DialectRegistry & registry)

Definition at line 174 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerShardingPropagation()

void mlir::shard::registerShardingPropagation ( )
inline

Definition at line 221 of file Passes.h.

◆ registerShardingPropagationPass()

void mlir::shard::registerShardingPropagationPass ( )
inline

Definition at line 228 of file Passes.h.

◆ registerShardPasses()

void mlir::shard::registerShardPasses ( )
inline

Definition at line 242 of file Passes.h.

Referenced by mlir::registerAllPasses().

◆ removeTrailingEmptySubArray()

template<typename T>
void mlir::shard::removeTrailingEmptySubArray ( SmallVector< SmallVector< T > > & array)

◆ reshard() [1/3]

TypedValue< ShapedType > mlir::shard::reshard ( ImplicitLocOpBuilder & builder,
GridOp grid,
Sharding sourceSharding,
Sharding targetSharding,
TypedValue< ShapedType > sourceUnshardedValue,
TypedValue< ShapedType > sourceShard )
static

◆ reshard() [2/3]

TypedValue< ShapedType > mlir::shard::reshard ( OpBuilder & builder,
GridOp grid,
ShardOp source,
ShardOp target,
TypedValue< ShapedType > sourceShardValue )

Definition at line 504 of file Partition.cpp.

References reshard(), and target.

Referenced by partitionOperation(), reshard(), and reshard().

◆ reshard() [3/3]

TypedValue< ShapedType > mlir::shard::reshard ( OpBuilder & builder,
ShardOp source,
ShardOp target,
TypedValue< ShapedType > sourceShardValue,
SymbolTableCollection & symbolTableCollection )

Definition at line 515 of file Partition.cpp.

References getGrid(), reshard(), and target.

◆ reshardingRegisterDependentDialects()

void mlir::shard::reshardingRegisterDependentDialects ( DialectRegistry & registry)

Definition at line 524 of file Partition.cpp.

References mlir::DialectRegistry::insert().

◆ reshardOn1DGrid()

TypedValue< ShapedType > mlir::shard::reshardOn1DGrid ( ImplicitLocOpBuilder & builder,
GridOp grid,
Sharding sourceSharding,
Sharding targetSharding,
TypedValue< ShapedType > sourceUnshardedValue,
TypedValue< ShapedType > sourceShard )
static

◆ shardDimension()

int64_t mlir::shard::shardDimension ( int64_t dimSize,
int64_t shardCount )
inline

Definition at line 168 of file ShardOps.h.

Referenced by allToAllResultShapeInMoveLastAxis().

◆ shardedBlockArgumentTypes()

SmallVector< Type > mlir::shard::shardedBlockArgumentTypes ( Block & block,
SymbolTableCollection & symbolTableCollection )
static

◆ shardShapedType()

◆ shardType()

Type mlir::shard::shardType ( Type type,
GridOp grid,
Sharding sharding )

Definition at line 291 of file ShardOps.cpp.

References shardShapedType().

Referenced by partitionTriviallyShardableOperation().

◆ splitLastAxisInResharding()

std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::splitLastAxisInResharding ( ImplicitLocOpBuilder & builder,
Sharding sourceSharding,
TypedValue< ShapedType > sourceShard,
GridOp grid,
int64_t splitTensorAxis,
GridAxis splitGridAxis )
static

Definition at line 68 of file Partition.cpp.

Referenced by trySplitLastAxisInResharding().

◆ targetShardingInMoveLastAxis()

Sharding mlir::shard::targetShardingInMoveLastAxis ( MLIRContext * ctx,
Sharding sourceSharding,
int64_t sourceTensorAxis,
int64_t targetTensorAxis )
static

◆ targetShardingInSplitLastAxis()

Sharding mlir::shard::targetShardingInSplitLastAxis ( MLIRContext * ctx,
Sharding sourceSharding,
int64_t splitTensorAxis,
GridAxis splitGridAxis )
static

Definition at line 46 of file Partition.cpp.

◆ targetShardingInUnsplitLastAxis()

Sharding mlir::shard::targetShardingInUnsplitLastAxis ( MLIRContext * ctx,
Sharding sourceSharding,
int64_t splitTensorAxis )
static

◆ tryMoveLastSplitAxisInResharding()

std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryMoveLastSplitAxisInResharding ( ImplicitLocOpBuilder & builder,
GridOp grid,
Sharding sourceSharding,
Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue< ShapedType > sourceShard )
static

◆ trySplitLastAxisInResharding()

std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::trySplitLastAxisInResharding ( ImplicitLocOpBuilder & builder,
GridOp grid,
Sharding sourceSharding,
Sharding targetSharding,
TypedValue< ShapedType > sourceShard )
static

Definition at line 119 of file Partition.cpp.

References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().

Referenced by reshardOn1DGrid().

◆ tryUnsplitLastAxisInResharding()

std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryUnsplitLastAxisInResharding ( ImplicitLocOpBuilder & builder,
GridOp grid,
Sharding sourceSharding,
Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue< ShapedType > sourceShard )
static

Definition at line 213 of file Partition.cpp.

References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().

Referenced by reshardOn1DGrid().

◆ tryUpdateHaloInResharding()

std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryUpdateHaloInResharding ( ImplicitLocOpBuilder & builder,
GridOp grid,
Sharding sourceSharding,
Sharding targetSharding,
ShapedType sourceUnshardedShape,
TypedValue< ShapedType > sourceShard )
static

◆ unsplitLastAxisInResharding()

std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::unsplitLastAxisInResharding ( ImplicitLocOpBuilder & builder,
Sharding sourceSharding,
ShapedType sourceUnshardedShape,
TypedValue< ShapedType > sourceShard,
GridOp grid,
int64_t splitTensorAxis,
GridAxis splitGridAxis )
static