MLIR 23.0.0git
mlir::shard Namespace Reference

Namespaces

namespace  detail
namespace  impl

Classes

struct  ElementwiseShardingInterface
struct  IndependentParallelIteratorDomainShardingInterface
class  MoveSplitAxisPattern
 Move a split axis between tensor dimensions: e.g. [[0], []] -> [[], [0]]. More...
struct  OpRewritePatternWithSymbolTableCollection
class  ReshardingPattern
 Base class for resharding patterns. More...
class  Sharding
struct  ShardingOption
struct  ShardingPropagationOptions
class  SplitLastAxisPattern
 Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]]. More...
class  UnsplitLastAxesPattern
 Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> []. More...
class  UpdateHaloPattern
 Update halo sizes: handles cases where only the halo sizes differ between source and target sharding. More...

Typedefs

using ShardingArray = SmallVector<SmallVector<GridAxis>>
using ShardingArrayRef = ArrayRef<SmallVector<GridAxis>>
using GridAxis = int16_t
using GridAxesAttr = DenseI16ArrayAttr
using ShardShapeAttr = DenseI64ArrayAttr
using HaloSizePairAttr = DenseI64ArrayAttr
using UnshardedToShardedValueMap = DenseMap<Value, Value>

Enumerations

enum class  TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward }
 This enum controls the traversal order for the sharding propagation. More...

Functions

FailureOr< std::pair< bool, Sharding > > getSharding (OpResult result)
FailureOr< std::pair< bool, Sharding > > getSharding (OpOperand &opOperand)
void partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
ShardingArray getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
SmallVector< GridAxis > getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
void partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
llvm::raw_ostream & operator<< (llvm::raw_ostream &os, const Sharding &sharding)
Diagnostic & operator<< (Diagnostic &diag, const Sharding &sharding)
bool isReductionLoop (utils::IteratorType iType)
template<typename T>
void removeTrailingEmptySubArray (SmallVector< SmallVector< T > > &array)
bool isFullReplication (Sharding sharding)
shard::GridOp getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
shard::GridOp getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
template<typename Op>
shard::GridOp getGrid (Op op, SymbolTableCollection &symbolTableCollection)
template<>
shard::GridOp getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection)
template<typename GridAxesRange, typename GridShapeRange>
int64_t collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape)
template<typename GridAxesRange>
int64_t collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid)
int64_t shardDimension (int64_t dimSize, int64_t shardCount)
int64_t gatherDimension (int64_t dimSize, int64_t shardCount)
ShapedType shardShapedType (ShapedType shape, GridOp grid, Sharding sharding)
Type shardType (Type type, GridOp grid, Sharding sharding)
void maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder)
void maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder)
SmallVector< Value > getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type())
 Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
TypedValue< ShapedType > reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
TypedValue< ShapedType > reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection)
void reshardingRegisterDependentDialects (DialectRegistry &registry)
std::unique_ptr< ::mlir::Pass > createPartition ()
std::unique_ptr< ::mlir::Pass > createShardSimplify ()
std::unique_ptr< ::mlir::Pass > createShardingPropagation ()
std::unique_ptr< ::mlir::Pass > createShardingPropagation (ShardingPropagationOptions options)
void registerPartition ()
void registerPartitionPass ()
void registerShardSimplify ()
void registerShardSimplifyPass ()
void registerShardingPropagation ()
void registerShardingPropagationPass ()
void registerShardPasses ()
template<typename AlgebraicOp>
void populateAllReduceEndomorphismSimplifyPatterns (RewritePatternSet &patterns, ReductionKind reduction)
void populateSimplifyPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry)
void populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerAllSliceOpLoweringDialects (DialectRegistry &registry)
void populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection)
void registerAllOpLoweringDialects (DialectRegistry &registry)
TypedValue< IndexType > createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder)
TypedValue< IndexType > createProcessLinearIndex (ImplicitLocOpBuilder &builder, StringRef grid, ArrayRef< GridAxis > gridAxes={})
TypedValue< IndexType > createProcessLinearIndex (ImplicitLocOpBuilder &builder, StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes={})
template<typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
static TypedValue< ShapedType > reshard (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding, const Sharding &tgtSharding, TypedValue< ShapedType > unshardedSrc, TypedValue< ShapedType > shardedSrc)
static SmallVector< Type > shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings (Operation &op)
static std::vector< Sharding > getResultShardings (Operation &op)
static LogicalResult partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult checkFullyAnnotated (Block &block)
static LogicalResult checkFullyAnnotated (Operation *op)
static LogicalResult partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)

Typedef Documentation

◆ GridAxesAttr

using mlir::shard::GridAxesAttr = DenseI16ArrayAttr

Definition at line 28 of file ShardOps.h.

◆ GridAxis

using mlir::shard::GridAxis = int16_t

Definition at line 27 of file ShardOps.h.

◆ HaloSizePairAttr

using mlir::shard::HaloSizePairAttr = DenseI64ArrayAttr

Definition at line 30 of file ShardOps.h.

◆ ShardingArray

using mlir::shard::ShardingArray = SmallVector<SmallVector<GridAxis>>

◆ ShardingArrayRef

using mlir::shard::ShardingArrayRef = ArrayRef<SmallVector<GridAxis>>

◆ ShardShapeAttr

using mlir::shard::ShardShapeAttr = DenseI64ArrayAttr

Definition at line 29 of file ShardOps.h.

◆ UnshardedToShardedValueMap

using mlir::shard::UnshardedToShardedValueMap = DenseMap<Value, Value>

Enumeration Type Documentation

◆ TraversalOrder

enum class mlir::shard::TraversalOrder
strong

This enum controls the traversal order for the sharding propagation.

Enumerator
Forward 

Forward traversal.

Backward 

Backward traversal.

ForwardBackward 

Forward then backward traversal.

BackwardForward 

Backward then forward traversal.

Definition at line 23 of file Passes.h.

Function Documentation

◆ arePartialAxesCompatible()

template<typename SourceAxes, typename TargetAxes>
bool mlir::shard::arePartialAxesCompatible ( const SourceAxes & sourceAxes,
const TargetAxes & targetAxes )
static

Definition at line 42 of file Partition.cpp.

◆ checkFullyAnnotated() [1/2]

LogicalResult mlir::shard::checkFullyAnnotated ( Block & block)
static

◆ checkFullyAnnotated() [2/2]

LogicalResult mlir::shard::checkFullyAnnotated ( Operation * op)
static

◆ collectiveProcessGroupSize() [1/2]

template<typename GridAxesRange>
int64_t mlir::shard::collectiveProcessGroupSize ( GridAxesRange && gridAxes,
GridOp grid )

Definition at line 172 of file ShardOps.h.

References collectiveProcessGroupSize().

◆ collectiveProcessGroupSize() [2/2]

template<typename GridAxesRange, typename GridShapeRange>
int64_t mlir::shard::collectiveProcessGroupSize ( GridAxesRange && gridAxes,
GridShapeRange && gridShape )

◆ createCollectiveProcessGroupSize()

TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize ( GridOp grid,
ArrayRef< GridAxis > axes,
ImplicitLocOpBuilder & builder )

◆ createPartition()

std::unique_ptr<::mlir::Pass > mlir::shard::createPartition ( )

We declare an explicit private instantiation because Pass classes should only be visible to the current library.

Definition at line 81 of file Partition.cpp.

◆ createProcessLinearIndex() [1/2]

TypedValue< IndexType > mlir::shard::createProcessLinearIndex ( ImplicitLocOpBuilder & builder,
StringRef grid,
ArrayRef< GridAxis > gridAxes = {} )

◆ createProcessLinearIndex() [2/2]

TypedValue< IndexType > mlir::shard::createProcessLinearIndex ( ImplicitLocOpBuilder & builder,
StringRef grid,
ValueRange processInGroupMultiIndex,
ArrayRef< GridAxis > gridAxes = {} )

Definition at line 211 of file Transforms.cpp.

References mlir::arith::ConstantIndexOp::create().

◆ createShardingPropagation() [1/2]

std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation ( )

Definition at line 257 of file ShardingPropagation.cpp.

◆ createShardingPropagation() [2/2]

std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation ( ShardingPropagationOptions options)

Definition at line 261 of file ShardingPropagation.cpp.

References b.

◆ createShardSimplify()

std::unique_ptr<::mlir::Pass > mlir::shard::createShardSimplify ( )

We declare an explicit private instantiation because Pass classes should only be visible to the current library.

Definition at line 157 of file Simplify.cpp.

◆ gatherDimension()

int64_t mlir::shard::gatherDimension ( int64_t dimSize,
int64_t shardCount )
inline

Definition at line 187 of file ShardOps.h.

◆ getGrid() [1/2]

template<typename Op>
shard::GridOp mlir::shard::getGrid ( Op op,
SymbolTableCollection & symbolTableCollection )

Definition at line 140 of file ShardOps.h.

References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().

◆ getGrid() [2/2]

shard::GridOp mlir::shard::getGrid ( Operation * op,
FlatSymbolRefAttr gridSymbol,
SymbolTableCollection & symbolTableCollection )
inline

◆ getGrid< ShardOp >()

template<>
shard::GridOp mlir::shard::getGrid< ShardOp > ( ShardOp op,
SymbolTableCollection & symbolTableCollection )
inline

Definition at line 145 of file ShardOps.h.

References getGrid().

◆ getGridAxisAssignmentForLoopIterators()

ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators ( ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
ArrayRef< utils::IteratorType > loopIteratorTypes,
ArrayRef< AffineMap > indexingMaps )

Definition at line 572 of file ShardingInterface.cpp.

References updateGridAxisAssignmentForLoopIterators().

◆ getGridOrNull()

shard::GridOp mlir::shard::getGridOrNull ( Operation * op,
FlatSymbolRefAttr gridSymbol,
SymbolTableCollection & symbolTableCollection )
inline

◆ getMixedAsValues()

SmallVector< Value > mlir::shard::getMixedAsValues ( OpBuilder b,
const Location & loc,
llvm::ArrayRef< int64_t > statics,
ValueRange dynamics,
Type type = Type() )

Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.

Definition at line 77 of file ShardOps.cpp.

References b.

◆ getOperandShardings()

std::vector< Sharding > mlir::shard::getOperandShardings ( Operation & op)
static

◆ getReductionGridAxes()

SmallVector< GridAxis > mlir::shard::getReductionGridAxes ( ArrayRef< utils::IteratorType > loopIteratorTypes,
ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators )

◆ getResultShardings()

std::vector< Sharding > mlir::shard::getResultShardings ( Operation & op)
static

◆ getSharding() [1/2]

FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding ( OpOperand & opOperand)

◆ getSharding() [2/2]

FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding ( OpResult result)

◆ isAtLeastOneReductionIteratorSharded()

bool mlir::shard::isAtLeastOneReductionIteratorSharded ( ArrayRef< utils::IteratorType > loopIteratorTypes,
ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators )

Definition at line 612 of file ShardingInterface.cpp.

◆ isFullReplication()

bool mlir::shard::isFullReplication ( Sharding sharding)
inline

◆ isReductionLoop()

bool mlir::shard::isReductionLoop ( utils::IteratorType iType)
inline

Definition at line 104 of file ShardOps.h.

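◆ maybeInsertSourceShardingAnnotation()

void mlir::shard::maybeInsertSourceShardingAnnotation ( Sharding sharding,
OpOperand & operand,
OpBuilder & builder )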

◆ maybeInsertTargetShardingAnnotation()

void mlir::shard::maybeInsertTargetShardingAnnotation ( Sharding sharding,
OpResult result,
OpBuilder & builder )

Definition at line 338 of file ShardOps.cpp.

References maybeInsertTargetShardingAnnotationImpl(), and result.

Referenced by addShardOp().

◆ operator<<() [1/2]

Diagnostic & mlir::shard::operator<< ( Diagnostic & diag,
const Sharding & sharding )
inline

Definition at line 85 of file ShardOps.h.

References diag().

◆ operator<<() [2/2]

llvm::raw_ostream & mlir::shard::operator<< ( llvm::raw_ostream & os,
const Sharding & sharding )

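◆ partitionBlock()

LogicalResult mlir::shard::partitionBlock ( Block & block,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection,
OpBuilder & builder )
static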

◆ partitionFullyReplicatedOperation()

void mlir::shard::partitionFullyReplicatedOperation ( Operation & op,
ArrayRef< Value > partitionedOperands,
ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
IRMapping & partitionMap,
SymbolTableCollection & symbolTable,
OpBuilder & builder )

◆ partitionFuncOp()

LogicalResult mlir::shard::partitionFuncOp ( FunctionOpInterface op,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection )
static

◆ partitionOperation() [1/3]

LogicalResult mlir::shard::partitionOperation ( Operation & op,
ArrayRef< Value > partitionedOperands,
ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection,
OpBuilder & builder )
static

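◆ partitionOperation() [2/3]

LogicalResult mlir::shard::partitionOperation ( Operation & op,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection,
OpBuilder & builder )
static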

◆ partitionOperation() [3/3]

LogicalResult mlir::shard::partitionOperation ( ShardOp shardOp,
IRMapping & partitionMap,
SymbolTableCollection & symbolTableCollection,
OpBuilder & builder )
static

◆ partitionTriviallyShardableOperation()

void mlir::shard::partitionTriviallyShardableOperation ( Operation & op,
ArrayRef< Value > partitionedOperands,
ArrayRef< Sharding > operandShardings,
ArrayRef< Sharding > resultShardings,
IRMapping & partitionMap,
SymbolTableCollection & symbolTable,
OpBuilder & builder )

◆ populateAllOpLoweringPatterns()

void mlir::shard::populateAllOpLoweringPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

◆ populateAllReduceEndomorphismSimplifyPatterns()

template<typename AlgebraicOp>
void mlir::shard::populateAllReduceEndomorphismSimplifyPatterns ( RewritePatternSet & patterns,
ReductionKind reduction )

◆ populateAllSliceOpLoweringPatterns()

void mlir::shard::populateAllSliceOpLoweringPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

◆ populateFoldingPatterns()

void mlir::shard::populateFoldingPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

◆ populateProcessMultiIndexOpLoweringPatterns()

void mlir::shard::populateProcessMultiIndexOpLoweringPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

◆ populateSimplifyPatterns()

void mlir::shard::populateSimplifyPatterns ( RewritePatternSet & patterns,
SymbolTableCollection & symbolTableCollection )

Definition at line 136 of file Simplify.cpp.

◆ registerAllOpLoweringDialects()

void mlir::shard::registerAllOpLoweringDialects ( DialectRegistry & registry)

◆ registerAllSliceOpLoweringDialects()

void mlir::shard::registerAllSliceOpLoweringDialects ( DialectRegistry & registry)

Definition at line 183 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerPartition()

void mlir::shard::registerPartition ( )
inline

Definition at line 278 of file Passes.h.

◆ registerPartitionPass()

void mlir::shard::registerPartitionPass ( )
inline

Definition at line 285 of file Passes.h.

◆ registerProcessMultiIndexOpLoweringDialects()

void mlir::shard::registerProcessMultiIndexOpLoweringDialects ( DialectRegistry & registry)

Definition at line 173 of file Transforms.cpp.

References mlir::DialectRegistry::insert().

Referenced by registerAllOpLoweringDialects().

◆ registerShardingPropagation()

void mlir::shard::registerShardingPropagation ( )
inline

Definition at line 320 of file Passes.h.

◆ registerShardingPropagationPass()

void mlir::shard::registerShardingPropagationPass ( )
inline

Definition at line 327 of file Passes.h.

◆ registerShardPasses()

void mlir::shard::registerShardPasses ( )
inline

Definition at line 341 of file Passes.h.

Referenced by mlir::registerAllPasses().

◆ registerShardSimplify()

void mlir::shard::registerShardSimplify ( )
inline

Definition at line 299 of file Passes.h.

◆ registerShardSimplifyPass()

void mlir::shard::registerShardSimplifyPass ( )
inline

Definition at line 306 of file Passes.h.

◆ removeTrailingEmptySubArray()

template<typename T>
void mlir::shard::removeTrailingEmptySubArray ( SmallVector< SmallVector< T > > & array)

◆ reshard() [1/3]

TypedValue< ShapedType > mlir::shard::reshard ( ImplicitLocOpBuilder & builder,
GridOp grid,
const Sharding & srcSharding,
const Sharding & tgtSharding,
TypedValue< ShapedType > unshardedSrc,
TypedValue< ShapedType > shardedSrc )
static

◆ reshard() [2/3]

TypedValue< ShapedType > mlir::shard::reshard ( OpBuilder & builder,
GridOp grid,
ShardOp source,
ShardOp target,
TypedValue< ShapedType > sourceShardValue )

Definition at line 495 of file Partition.cpp.

References reshard().

Referenced by partitionOperation(), reshard(), and reshard().

◆ reshard() [3/3]

TypedValue< ShapedType > mlir::shard::reshard ( OpBuilder & builder,
ShardOp source,
ShardOp target,
TypedValue< ShapedType > sourceShardValue,
SymbolTableCollection & symbolTableCollection )

Definition at line 506 of file Partition.cpp.

References getGrid(), and reshard().

◆ reshardingRegisterDependentDialects()

void mlir::shard::reshardingRegisterDependentDialects ( DialectRegistry & registry)

Definition at line 515 of file Partition.cpp.

References mlir::DialectRegistry::insert().

◆ shardDimension()

int64_t mlir::shard::shardDimension ( int64_t dimSize,
int64_t shardCount )
inline

Definition at line 178 of file ShardOps.h.

◆ shardedBlockArgumentTypes()

SmallVector< Type > mlir::shard::shardedBlockArgumentTypes ( Block & block,
SymbolTableCollection & symbolTableCollection )
static

◆ shardShapedType()

ShapedType mlir::shard::shardShapedType ( ShapedType shape,
GridOp grid,
Sharding sharding )

◆ shardType()

Type mlir::shard::shardType ( Type type,
GridOp grid,
Sharding sharding )

Definition at line 291 of file ShardOps.cpp.

References shardShapedType().

Referenced by partitionTriviallyShardableOperation().