MLIR
22.0.0git
|
Namespaces | |
detail | |
Classes | |
struct | ShardingOption |
struct | IndependentParallelIteratorDomainShardingInterface |
struct | ElementwiseShardingInterface |
class | Sharding |
struct | OpRewritePatternWithSymbolTableCollection |
Typedefs | |
using | ShardingArray = SmallVector< SmallVector< GridAxis > > |
using | ShardingArrayRef = ArrayRef< SmallVector< GridAxis > > |
using | GridAxis = int16_t |
using | GridAxesAttr = DenseI16ArrayAttr |
using | ShardShapeAttr = DenseI64ArrayAttr |
using | HaloSizePairAttr = DenseI64ArrayAttr |
using | UnshardedToShardedValueMap = DenseMap< Value, Value > |
Enumerations | |
enum class | TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward } |
This enum controls the traversal order for the sharding propagation. More... | |
Functions | |
FailureOr< std::pair< bool, Sharding > > | getSharding (OpResult result) |
FailureOr< std::pair< bool, Sharding > > | getSharding (OpOperand &opOperand) |
void | partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
ShardingArray | getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps) |
bool | isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators) |
SmallVector< GridAxis > | getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators) |
void | partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
bool | isReductionLoop (utils::IteratorType iType) |
template<typename T > | |
void | removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array) |
bool | isFullReplication (Sharding sharding) |
shard::GridOp | getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
shard::GridOp | getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
template<typename Op > | |
shard::GridOp | getGrid (Op op, SymbolTableCollection &symbolTableCollection) |
template<> | |
shard::GridOp | getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection) |
template<typename GridAxesRange , typename GridShapeRange > | |
int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape) |
template<typename GridAxesRange > | |
int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid) |
int64_t | shardDimension (int64_t dimSize, int64_t shardCount) |
int64_t | gatherDimension (int64_t dimSize, int64_t shardCount) |
ShapedType | shardShapedType (ShapedType shape, GridOp grid, Sharding sharding) |
Type | shardType (Type type, GridOp grid, Sharding sharding) |
void | maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder) |
void | maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder) |
SmallVector< Value > | getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type()) |
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type. More... | |
TypedValue< ShapedType > | reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
TypedValue< ShapedType > | reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
void | reshardingRegisterDependentDialects (DialectRegistry &registry) |
template<typename AlgebraicOp > | |
void | populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction) |
void | populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry) |
void | populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerAllSliceOpLoweringDialects (DialectRegistry &registry) |
void | populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerAllOpLoweringDialects (DialectRegistry ®istry) |
TypedValue< IndexType > | createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder) |
TypedValue< IndexType > | createProcessLinearIndex (StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder) |
TypedValue< IndexType > | createProcessLinearIndex (StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder) |
template<typename SourceAxes , typename TargetAxes > | |
static bool | arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
static Sharding | targetShardingInSplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis) |
static std::tuple< TypedValue< ShapedType >, Sharding > | splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
static std::optional< std::tuple< int64_t, GridAxis > > | detectSplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< int64_t, GridAxis > > | detectUnsplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
static Sharding | targetShardingInUnsplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis) |
static ShapedType | allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) |
static std::tuple< TypedValue< ShapedType >, Sharding > | unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > | detectMoveLastSplitAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
static Sharding | targetShardingInMoveLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static ShapedType | allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static std::tuple< TypedValue< ShapedType >, Sharding > | moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis) |
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static TypedValue< ShapedType > | reshardOn1DGrid (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
static TypedValue< ShapedType > | reshard (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
static SmallVector< Type > | shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
static LogicalResult | partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static std::vector< Sharding > | getOperandShardings (Operation &op) |
static std::vector< Sharding > | getResultShardings (Operation &op) |
static LogicalResult | partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection) |
using mlir::shard::GridAxesAttr = typedef DenseI16ArrayAttr |
Definition at line 27 of file ShardOps.h.
using mlir::shard::GridAxis = typedef int16_t |
Definition at line 26 of file ShardOps.h.
using mlir::shard::HaloSizePairAttr = typedef DenseI64ArrayAttr |
Definition at line 29 of file ShardOps.h.
using mlir::shard::ShardingArray = typedef SmallVector<SmallVector<GridAxis> > |
Definition at line 25 of file ShardingInterface.h.
using mlir::shard::ShardingArrayRef = typedef ArrayRef<SmallVector<GridAxis> > |
Definition at line 26 of file ShardingInterface.h.
using mlir::shard::ShardShapeAttr = typedef DenseI64ArrayAttr |
Definition at line 28 of file ShardOps.h.
using mlir::shard::UnshardedToShardedValueMap = typedef DenseMap<Value, Value> |
Definition at line 533 of file Partition.cpp.
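enum class mlir::shard::TraversalOrder { Forward, Backward, ForwardBackward, BackwardForward } |
strong |
This enum controls the traversal order for the sharding propagation.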
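ShapedType mlir::shard::allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) |
static |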
Definition at line 180 of file Partition.cpp.
References gatherDimension().
Referenced by unsplitLastAxisInResharding().
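ShapedType mlir::shard::allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static |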
Definition at line 304 of file Partition.cpp.
References gatherDimension(), and shardDimension().
Referenced by moveLastSplitAxisInResharding().
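template<typename SourceAxes, typename TargetAxes>
bool mlir::shard::arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
static |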
Definition at line 39 of file Partition.cpp.
int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
GridOp | grid | ||
) |
Definition at line 162 of file ShardOps.h.
References collectiveProcessGroupSize().
int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
GridShapeRange && | gridShape | ||
) |
Definition at line 146 of file ShardOps.h.
Referenced by collectiveProcessGroupSize(), sliceResultType(), verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize | ( | GridOp | grid, |
ArrayRef< GridAxis > | axes, | ||
ImplicitLocOpBuilder & | builder | ||
) |
Definition at line 202 of file Transforms.cpp.
References mlir::arith::createProduct(), mlir::Builder::getIndexType(), and mlir::ImplicitLocOpBuilder::getLoc().
TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | StringRef | grid, |
ArrayRef< GridAxis > | gridAxes, | ||
ImplicitLocOpBuilder & | builder | ||
) |
Definition at line 228 of file Transforms.cpp.
Referenced by mlir::linalg::createDestinationPassingStyleInitOperand().
TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | StringRef | grid, |
ValueRange | processInGroupMultiIndex, | ||
ArrayRef< GridAxis > | gridAxes, | ||
ImplicitLocOpBuilder & | builder | ||
) |
Definition at line 212 of file Transforms.cpp.
References mlir::arith::ConstantIndexOp::create(), and mlir::affine::linearizeIndex().
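std::optional<std::tuple<int64_t, int64_t, GridAxis>> mlir::shard::detectMoveLastSplitAxisInResharding(Sharding sourceSharding, Sharding targetSharding) |
static |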
Definition at line 235 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryMoveLastSplitAxisInResharding().
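std::optional<std::tuple<int64_t, GridAxis>> mlir::shard::detectSplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding) |
static |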
Definition at line 87 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by trySplitLastAxisInResharding().
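std::optional<std::tuple<int64_t, GridAxis>> mlir::shard::detectUnsplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding) |
static |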
Definition at line 136 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryUnsplitLastAxisInResharding().
int64_t mlir::shard::gatherDimension(int64_t dimSize, int64_t shardCount) |
inline |
Definition at line 177 of file ShardOps.h.
Referenced by allGatherResultShapeInUnsplitLastAxis(), and allToAllResultShapeInMoveLastAxis().
shard::GridOp mlir::shard::getGrid | ( | Op | op, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 130 of file ShardOps.h.
References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().
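shard::GridOp mlir::shard::getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
inline |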
Definition at line 121 of file ShardOps.h.
References getGridOrNull().
Referenced by getGrid(), mlir::linalg::getGrid(), getGrid< ShardOp >(), reshard(), and shardedBlockArgumentTypes().
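template<>
shard::GridOp mlir::shard::getGrid<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) |
inline |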
Definition at line 135 of file ShardOps.h.
References getGrid().
ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators | ( | ArrayRef< Sharding > | operandShardings, |
ArrayRef< Sharding > | resultShardings, | ||
ArrayRef< utils::IteratorType > | loopIteratorTypes, | ||
ArrayRef< AffineMap > | indexingMaps | ||
) |
Definition at line 572 of file ShardingInterface.cpp.
References updateGridAxisAssignmentForLoopIterators().
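shard::GridOp mlir::shard::getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
inline |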
Definition at line 113 of file ShardOps.h.
References mlir::SymbolTableCollection::lookupNearestSymbolFrom().
Referenced by getGrid(), getGridAndVerify(), and partitionTriviallyShardableOperation().
SmallVector< Value > mlir::shard::getMixedAsValues | ( | OpBuilder | b, |
const Location & | loc, | ||
llvm::ArrayRef< int64_t > | statics, | ||
ValueRange | dynamics, | ||
Type | type = Type() |
||
) |
Converts a vector of OpFoldResults (ints) into vector of Values of the provided type.
Definition at line 77 of file ShardOps.cpp.
References mlir::Builder::getI64IntegerAttr(), mlir::Builder::getI64Type(), mlir::Builder::getIndexAttr(), and mlir::Builder::getIndexType().
std::vector<Sharding> mlir::shard::getOperandShardings(Operation &op) |
static |
Definition at line 591 of file Partition.cpp.
References mlir::Value::getDefiningOp(), mlir::Operation::getNumOperands(), and mlir::Operation::getOperands().
Referenced by partitionOperation().
SmallVector< GridAxis > mlir::shard::getReductionGridAxes | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
ArrayRef< SmallVector< GridAxis >> | gridAxisAssignmentForLoopIterators | ||
) |
Definition at line 625 of file ShardingInterface.cpp.
Referenced by mlir::linalg::partitionLinalgOpWithShardedReduction().
std::vector<Sharding> mlir::shard::getResultShardings(Operation &op) |
static |
Definition at line 611 of file Partition.cpp.
References mlir::Operation::getNumResults(), mlir::Operation::getOperands(), mlir::Operation::getResults(), and mlir::Value::getUsers().
Referenced by partitionOperation().
FailureOr<std::pair<bool, Sharding>> mlir::shard::getSharding(OpOperand &opOperand)
Definition at line 152 of file ShardingInterface.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), and mlir::Value::getDefiningOp().
FailureOr<std::pair<bool, Sharding>> mlir::shard::getSharding(OpResult result)
Definition at line 109 of file ShardingInterface.cpp.
References mlir::Value::getUsers(), and mlir::Value::hasOneUse().
Referenced by addShardOp(), mlir::shard::detail::defaultGetShardingAnnotations(), and visitOp().
bool mlir::shard::isAtLeastOneReductionIteratorSharded | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
ArrayRef< SmallVector< GridAxis >> | gridAxisAssignmentForLoopIterators | ||
) |
Definition at line 612 of file ShardingInterface.cpp.
bool mlir::shard::isFullReplication(Sharding sharding) |
inline |
Definition at line 106 of file ShardOps.h.
References mlir::shard::Sharding::getSplitAxes().
Referenced by isValueCompatibleWithFullReplicationSharding(), maybeInsertSourceShardingAnnotation(), and reshard().
bool mlir::shard::isReductionLoop(utils::IteratorType iType) |
inline |
Definition at line 94 of file ShardOps.h.
void mlir::shard::maybeInsertSourceShardingAnnotation | ( | Sharding | sharding, |
OpOperand & | operand, | ||
OpBuilder & | builder | ||
) |
Definition at line 352 of file ShardOps.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getDefiningOp(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), mlir::Value::getType(), mlir::Operation::hasTrait(), isFullReplication(), mlir::RewriterBase::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by addShardOp().
void mlir::shard::maybeInsertTargetShardingAnnotation | ( | Sharding | sharding, |
OpResult | result, | ||
OpBuilder & | builder | ||
) |
Definition at line 338 of file ShardOps.cpp.
References mlir::Value::getUses(), and maybeInsertTargetShardingAnnotationImpl().
Referenced by addShardOp().
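std::tuple<TypedValue<ShapedType>, Sharding> mlir::shard::moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis) |
static |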
Definition at line 317 of file Partition.cpp.
References allToAllResultShapeInMoveLastAxis(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInMoveLastAxis().
Referenced by tryMoveLastSplitAxisInResharding().
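LogicalResult mlir::shard::partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static |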
Definition at line 703 of file Partition.cpp.
References mlir::OpBuilder::createBlock(), mlir::remark::failed(), mlir::Block::getArguments(), mlir::Block::getOperations(), mlir::Block::getParent(), mlir::IRMapping::map(), partitionOperation(), mlir::OpBuilder::setInsertionPointToEnd(), and shardedBlockArgumentTypes().
Referenced by partitionFuncOp().
void mlir::shard::partitionFullyReplicatedOperation | ( | Operation & | op, |
ArrayRef< Value > | partitionedOperands, | ||
ArrayRef< Sharding > | operandShardings, | ||
ArrayRef< Sharding > | resultShardings, | ||
IRMapping & | partitionMap, | ||
SymbolTableCollection & | symbolTable, | ||
OpBuilder & | builder | ||
) |
Definition at line 543 of file ShardingInterface.cpp.
References areValuesCompatibleWithFullReplicationShardings(), mlir::OpBuilder::clone(), mlir::Operation::getOperands(), and mlir::Operation::getResults().
Referenced by partitionOperation().
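LogicalResult mlir::shard::partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection) |
static |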
Definition at line 731 of file Partition.cpp.
References mlir::remark::failed(), mlir::get(), mlir::Operation::getOperandTypes(), and partitionBlock().
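LogicalResult mlir::shard::partitionOperation(Operation &op, ArrayRef<Value> partitionedOperands, ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static |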
Definition at line 562 of file Partition.cpp.
References mlir::remark::failed(), mlir::Operation::getResults(), and partitionFullyReplicatedOperation().
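LogicalResult mlir::shard::partitionOperation(Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static |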
Definition at line 669 of file Partition.cpp.
References mlir::OpBuilder::clone(), mlir::Operation::emitError(), mlir::Operation::getOperands(), getOperandShardings(), mlir::Operation::getResult(), getResultShardings(), and mlir::IRMapping::map().
Referenced by partitionBlock().
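LogicalResult mlir::shard::partitionOperation(ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static |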
Definition at line 645 of file Partition.cpp.
References mlir::IRMapping::contains(), mlir::Value::getDefiningOp(), mlir::IRMapping::lookup(), mlir::IRMapping::map(), and reshard().
void mlir::shard::partitionTriviallyShardableOperation | ( | Operation & | op, |
ArrayRef< Value > | partitionedOperands, | ||
ArrayRef< Sharding > | operandShardings, | ||
ArrayRef< Sharding > | resultShardings, | ||
IRMapping & | partitionMap, | ||
SymbolTableCollection & | symbolTable, | ||
OpBuilder & | builder | ||
) |
Definition at line 638 of file ShardingInterface.cpp.
References mlir::OpBuilder::clone(), getGridOrNull(), mlir::Operation::getResults(), and shardType().
Referenced by mlir::shard::IndependentParallelIteratorDomainShardingInterface< Op >::partition(), mlir::shard::ElementwiseShardingInterface< ElemwiseOp >::partition(), and mlir::linalg::partitionLinalgOpWithShardedReduction().
void mlir::shard::populateAllOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 190 of file Transforms.cpp.
References mlir::patterns, populateAllSliceOpLoweringPatterns(), and populateProcessMultiIndexOpLoweringPatterns().
void mlir::shard::populateAllReduceEndomorphismSimplificationPatterns | ( | RewritePatternSet & | patterns, |
ReductionKind | reduction | ||
) |
Definition at line 40 of file Simplifications.h.
References mlir::patterns.
void mlir::shard::populateAllSliceOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 178 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
void mlir::shard::populateFoldingPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 114 of file Simplifications.cpp.
References mlir::patterns.
Referenced by populateSimplificationPatterns().
void mlir::shard::populateProcessMultiIndexOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 168 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
void mlir::shard::populateSimplificationPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 23 of file Simplifications.cpp.
References mlir::patterns, and populateFoldingPatterns().
void mlir::shard::registerAllOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 196 of file Transforms.cpp.
References registerAllSliceOpLoweringDialects(), and registerProcessMultiIndexOpLoweringDialects().
void mlir::shard::registerAllSliceOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 184 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
void mlir::shard::registerProcessMultiIndexOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 174 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
void mlir::shard::removeTrailingEmptySubArray | ( | SmallVector< SmallVector< T >> & | array | ) |
Definition at line 100 of file ShardOps.h.
Referenced by mlir::shard::detail::defaultGetShardingOption(), and getSharding().
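TypedValue<ShapedType> mlir::shard::reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue, TypedValue<ShapedType> sourceShard) |
static |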
Definition at line 481 of file Partition.cpp.
References isFullReplication(), reshardOn1DGrid(), and tryUpdateHaloInResharding().
Referenced by partitionOperation().
TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
GridOp | grid, | ||
ShardOp | source, | ||
ShardOp | target, | ||
TypedValue< ShapedType > | sourceShardValue | ||
) |
Definition at line 505 of file Partition.cpp.
Referenced by reshard().
TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
ShardOp | source, | ||
ShardOp | target, | ||
TypedValue< ShapedType > | sourceShardValue, | ||
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 517 of file Partition.cpp.
void mlir::shard::reshardingRegisterDependentDialects | ( | DialectRegistry & | registry | ) |
Definition at line 526 of file Partition.cpp.
References mlir::DialectRegistry::insert().
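TypedValue<ShapedType> mlir::shard::reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue, TypedValue<ShapedType> sourceShard) |
static |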
Definition at line 439 of file Partition.cpp.
References mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), shardShapedType(), tryMoveLastSplitAxisInResharding(), trySplitLastAxisInResharding(), and tryUnsplitLastAxisInResharding().
Referenced by reshard().
int64_t mlir::shard::shardDimension(int64_t dimSize, int64_t shardCount) |
inline |
Definition at line 168 of file ShardOps.h.
Referenced by allToAllResultShapeInMoveLastAxis().
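SmallVector<Type> mlir::shard::shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection) |
static |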
Definition at line 539 of file Partition.cpp.
References mlir::Block::getArguments(), getGrid(), mlir::Operation::getUsers(), and shardShapedType().
Referenced by partitionBlock().
ShapedType mlir::shard::shardShapedType | ( | ShapedType | shape, |
GridOp | grid, | ||
Sharding | sharding | ||
) |
Definition at line 281 of file ShardOps.cpp.
References mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), and shardShape().
Referenced by moveLastSplitAxisInResharding(), reshardOn1DGrid(), shardedBlockArgumentTypes(), shardType(), and unsplitLastAxisInResharding().
Type mlir::shard::shardType(Type type, GridOp grid, Sharding sharding)
Definition at line 291 of file ShardOps.cpp.
References shardShapedType().
Referenced by partitionTriviallyShardableOperation().
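std::tuple<TypedValue<ShapedType>, Sharding> mlir::shard::splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue<ShapedType> sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
static |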
Definition at line 68 of file Partition.cpp.
References mlir::Builder::getContext(), and targetShardingInSplitLastAxis().
Referenced by trySplitLastAxisInResharding().
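Sharding mlir::shard::targetShardingInMoveLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static |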
Definition at line 276 of file Partition.cpp.
References mlir::shard::Sharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by moveLastSplitAxisInResharding().
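Sharding mlir::shard::targetShardingInSplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis) |
static |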
Definition at line 46 of file Partition.cpp.
References mlir::shard::Sharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by splitLastAxisInResharding().
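Sharding mlir::shard::targetShardingInUnsplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis) |
static |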
Definition at line 164 of file Partition.cpp.
References mlir::shard::Sharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by unsplitLastAxisInResharding().
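std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> mlir::shard::tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard) |
static |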
Definition at line 345 of file Partition.cpp.
References detectMoveLastSplitAxisInResharding(), and moveLastSplitAxisInResharding().
Referenced by reshardOn1DGrid().
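std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> mlir::shard::trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue<ShapedType> sourceShard) |
static |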
Definition at line 119 of file Partition.cpp.
References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().
Referenced by reshardOn1DGrid().
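std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> mlir::shard::tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard) |
static |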
Definition at line 214 of file Partition.cpp.
References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().
Referenced by reshardOn1DGrid().
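std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> mlir::shard::tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard) |
static |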
Definition at line 366 of file Partition.cpp.
References mlir::shard::Sharding::equalHaloSizes(), mlir::shard::Sharding::equalSplitAxes(), mlir::get(), mlir::Builder::getContext(), mlir::shard::Sharding::getDynamicHaloSizes(), mlir::ImplicitLocOpBuilder::getLoc(), mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), and mlir::shard::Sharding::getStaticShardedDimsOffsets().
Referenced by reshard().
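std::tuple<TypedValue<ShapedType>, Sharding> mlir::shard::unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
static |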
Definition at line 188 of file Partition.cpp.
References allGatherResultShapeInUnsplitLastAxis(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInUnsplitLastAxis().
Referenced by tryUnsplitLastAxisInResharding().