|
MLIR 22.0.0git
|
Namespaces | |
| namespace | detail |
| namespace | impl |
Classes | |
| struct | ElementwiseShardingInterface |
| struct | IndependentParallelIteratorDomainShardingInterface |
| struct | OpRewritePatternWithSymbolTableCollection |
| class | Sharding |
| struct | ShardingOption |
| struct | ShardingPropagationOptions |
Typedefs | |
| using | ShardingArray = SmallVector<SmallVector<GridAxis>> |
| using | ShardingArrayRef = ArrayRef<SmallVector<GridAxis>> |
| using | GridAxis = int16_t |
| using | GridAxesAttr = DenseI16ArrayAttr |
| using | ShardShapeAttr = DenseI64ArrayAttr |
| using | HaloSizePairAttr = DenseI64ArrayAttr |
| using | UnshardedToShardedValueMap = DenseMap<Value, Value> |
Enumerations | |
| enum class | TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward } |
| This enum controls the traversal order for the sharding propagation. More... | |
Functions | |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpResult result) |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpOperand &opOperand) |
| void | partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| ShardingArray | getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps) |
| bool | isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators) |
| SmallVector< GridAxis > | getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators) |
| void | partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| bool | isReductionLoop (utils::IteratorType iType) |
| template<typename T> | |
| void | removeTrailingEmptySubArray (SmallVector< SmallVector< T > > &array) |
| bool | isFullReplication (Sharding sharding) |
| shard::GridOp | getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| shard::GridOp | getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| template<typename Op> | |
| shard::GridOp | getGrid (Op op, SymbolTableCollection &symbolTableCollection) |
| template<> | |
| shard::GridOp | getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection) |
| template<typename GridAxesRange, typename GridShapeRange> | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape) |
| template<typename GridAxesRange> | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid) |
| int64_t | shardDimension (int64_t dimSize, int64_t shardCount) |
| int64_t | gatherDimension (int64_t dimSize, int64_t shardCount) |
| ShapedType | shardShapedType (ShapedType shape, GridOp grid, Sharding sharding) |
| Type | shardType (Type type, GridOp grid, Sharding sharding) |
| void | maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder) |
| void | maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder) |
| SmallVector< Value > | getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type()) |
| Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type. | |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
| void | reshardingRegisterDependentDialects (DialectRegistry &registry) |
| std::unique_ptr<::mlir::Pass > | createPartition () |
| std::unique_ptr<::mlir::Pass > | createShardingPropagation () |
| std::unique_ptr<::mlir::Pass > | createShardingPropagation (ShardingPropagationOptions options) |
| void | registerPartition () |
| void | registerPartitionPass () |
| void | registerShardingPropagation () |
| void | registerShardingPropagationPass () |
| void | registerShardPasses () |
| template<typename AlgebraicOp> | |
| void | populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction) |
| void | populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllSliceOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllOpLoweringDialects (DialectRegistry &registry) |
| TypedValue< IndexType > | createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder) |
| TypedValue< IndexType > | createProcessLinearIndex (StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder) |
| TypedValue< IndexType > | createProcessLinearIndex (StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder) |
| template<typename SourceAxes, typename TargetAxes> | |
| static bool | arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
| static Sharding | targetShardingInSplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::optional< std::tuple< int64_t, GridAxis > > | detectSplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, GridAxis > > | detectUnsplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
| static Sharding | targetShardingInUnsplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis) |
| static ShapedType | allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, int64_t, GridAxis > > | detectMoveLastSplitAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
| static Sharding | targetShardingInMoveLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static ShapedType | allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static TypedValue< ShapedType > | reshardOn1DGrid (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
| static TypedValue< ShapedType > | reshard (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
| static SmallVector< Type > | shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
| static LogicalResult | partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static std::vector< Sharding > | getOperandShardings (Operation &op) |
| static std::vector< Sharding > | getResultShardings (Operation &op) |
| static LogicalResult | partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection) |
Definition at line 27 of file ShardOps.h.
| using mlir::shard::GridAxis = int16_t |
Definition at line 26 of file ShardOps.h.
Definition at line 29 of file ShardOps.h.
Definition at line 25 of file ShardingInterface.h.
Definition at line 26 of file ShardingInterface.h.
Definition at line 28 of file ShardOps.h.
Definition at line 531 of file Partition.cpp.
|
strong |
|
static |
Definition at line 180 of file Partition.cpp.
References gatherDimension().
Referenced by unsplitLastAxisInResharding().
|
static |
Definition at line 303 of file Partition.cpp.
References gatherDimension(), and shardDimension().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 39 of file Partition.cpp.
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridOp | grid ) |
Definition at line 162 of file ShardOps.h.
References collectiveProcessGroupSize().
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridShapeRange && | gridShape ) |
Definition at line 146 of file ShardOps.h.
Referenced by collectiveProcessGroupSize(), sliceResultType(), verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
| TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize | ( | GridOp | grid, |
| ArrayRef< GridAxis > | axes, | ||
| ImplicitLocOpBuilder & | builder ) |
Definition at line 202 of file Transforms.cpp.
References mlir::arith::createProduct(), mlir::Builder::getIndexType(), and mlir::ImplicitLocOpBuilder::getLoc().
| std::unique_ptr<::mlir::Pass > mlir::shard::createPartition | ( | ) |
We declare an explicit private instantiation because Pass classes should only be visible to the current library.
Definition at line 80 of file Partition.cpp.
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | StringRef | grid, |
| ArrayRef< GridAxis > | gridAxes, | ||
| ImplicitLocOpBuilder & | builder ) |
Definition at line 228 of file Transforms.cpp.
References createProcessLinearIndex().
Referenced by mlir::linalg::createDestinationPassingStyleInitOperand(), and createProcessLinearIndex().
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | StringRef | grid, |
| ValueRange | processInGroupMultiIndex, | ||
| ArrayRef< GridAxis > | gridAxes, | ||
| ImplicitLocOpBuilder & | builder ) |
Definition at line 212 of file Transforms.cpp.
References mlir::arith::ConstantIndexOp::create().
| std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation | ( | ) |
Definition at line 180 of file ShardingPropagation.cpp.
| std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation | ( | ShardingPropagationOptions | options | ) |
Definition at line 184 of file ShardingPropagation.cpp.
References NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS, RESHARDING_FOR_EXPLICIT_ANNOTATIONS, and result.
|
static |
Definition at line 234 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryMoveLastSplitAxisInResharding().
|
static |
Definition at line 87 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 136 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryUnsplitLastAxisInResharding().
Definition at line 177 of file ShardOps.h.
Referenced by allGatherResultShapeInUnsplitLastAxis(), and allToAllResultShapeInMoveLastAxis().
| shard::GridOp mlir::shard::getGrid | ( | Op | op, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 130 of file ShardOps.h.
References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().
|
inline |
Definition at line 121 of file ShardOps.h.
References getGridOrNull().
Referenced by mlir::linalg::getGrid(), getGrid(), getGrid< ShardOp >(), reshard(), and shardedBlockArgumentTypes().
|
inline |
Definition at line 135 of file ShardOps.h.
References getGrid().
| ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators | ( | ArrayRef< Sharding > | operandShardings, |
| ArrayRef< Sharding > | resultShardings, | ||
| ArrayRef< utils::IteratorType > | loopIteratorTypes, | ||
| ArrayRef< AffineMap > | indexingMaps ) |
Definition at line 572 of file ShardingInterface.cpp.
References updateGridAxisAssignmentForLoopIterators().
|
inline |
Definition at line 113 of file ShardOps.h.
References mlir::SymbolTableCollection::lookupNearestSymbolFrom().
Referenced by getGrid(), getGridAndVerify(), and partitionTriviallyShardableOperation().
| SmallVector< Value > mlir::shard::getMixedAsValues | ( | OpBuilder | b, |
| const Location & | loc, | ||
| llvm::ArrayRef< int64_t > | statics, | ||
| ValueRange | dynamics, | ||
| Type | type = Type() ) |
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
Definition at line 77 of file ShardOps.cpp.
References b.
Definition at line 589 of file Partition.cpp.
References mlir::Value::getDefiningOp(), mlir::Operation::getNumOperands(), mlir::Operation::getOperands(), and getOperandShardings().
Referenced by getOperandShardings(), and partitionOperation().
| SmallVector< GridAxis > mlir::shard::getReductionGridAxes | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis > > | gridAxisAssignmentForLoopIterators ) |
Definition at line 625 of file ShardingInterface.cpp.
Referenced by mlir::linalg::partitionLinalgOpWithShardedReduction().
Definition at line 609 of file Partition.cpp.
References mlir::Operation::getNumResults(), mlir::Operation::getOperands(), mlir::Operation::getResults(), getResultShardings(), if(), and result.
Referenced by getResultShardings(), and partitionOperation().
Definition at line 152 of file ShardingInterface.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), and mlir::Value::getDefiningOp().
Definition at line 109 of file ShardingInterface.cpp.
References getSharding(), mlir::Value::getUsers(), mlir::Value::hasOneUse(), and result.
Referenced by addShardOp(), addShardOp(), mlir::shard::detail::defaultGetShardingAnnotations(), getSharding(), and visitOp().
| bool mlir::shard::isAtLeastOneReductionIteratorSharded | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis > > | gridAxisAssignmentForLoopIterators ) |
Definition at line 612 of file ShardingInterface.cpp.
Definition at line 106 of file ShardOps.h.
References mlir::shard::Sharding::getSplitAxes().
Referenced by isValueCompatibleWithFullReplicationSharding(), maybeInsertSourceShardingAnnotation(), and reshard().
|
inline |
Definition at line 94 of file ShardOps.h.
| void mlir::shard::maybeInsertSourceShardingAnnotation | ( | Sharding | sharding, |
| OpOperand & | operand, | ||
| OpBuilder & | builder ) |
Definition at line 352 of file ShardOps.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getDefiningOp(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), mlir::Value::getType(), mlir::Operation::hasTrait(), isFullReplication(), mlir::RewriterBase::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by addShardOp().
| void mlir::shard::maybeInsertTargetShardingAnnotation | ( | Sharding | sharding, |
| OpResult | result, | ||
| OpBuilder & | builder ) |
Definition at line 338 of file ShardOps.cpp.
References maybeInsertTargetShardingAnnotationImpl(), and result.
Referenced by addShardOp().
|
static |
Definition at line 316 of file Partition.cpp.
References allToAllResultShapeInMoveLastAxis(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInMoveLastAxis().
Referenced by tryMoveLastSplitAxisInResharding().
|
static |
Definition at line 701 of file Partition.cpp.
References mlir::OpBuilder::createBlock(), mlir::Block::getArguments(), mlir::Block::getOperations(), mlir::Block::getParent(), mlir::IRMapping::map(), partitionBlock(), partitionOperation(), mlir::OpBuilder::setInsertionPointToEnd(), shardedBlockArgumentTypes(), and success().
Referenced by partitionBlock(), and partitionFuncOp().
| void mlir::shard::partitionFullyReplicatedOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder ) |
Definition at line 543 of file ShardingInterface.cpp.
References areValuesCompatibleWithFullReplicationShardings(), mlir::OpBuilder::clone(), mlir::Operation::getOperands(), and mlir::Operation::getResults().
Referenced by partitionOperation().
|
static |
Definition at line 729 of file Partition.cpp.
References b, mlir::Operation::getOperandTypes(), partitionBlock(), partitionFuncOp(), and success().
Referenced by partitionFuncOp().
|
static |
Definition at line 560 of file Partition.cpp.
References mlir::Operation::getResults(), partitionFullyReplicatedOperation(), partitionOperation(), result, and success().
Referenced by partitionBlock(), partitionOperation(), partitionOperation(), and partitionOperation().
|
static |
Definition at line 667 of file Partition.cpp.
References mlir::OpBuilder::clone(), mlir::Operation::emitError(), mlir::Operation::getOperands(), getOperandShardings(), mlir::Operation::getResult(), getResultShardings(), mlir::IRMapping::map(), partitionOperation(), and success().
|
static |
Definition at line 643 of file Partition.cpp.
References mlir::IRMapping::contains(), mlir::Value::getDefiningOp(), mlir::IRMapping::lookup(), mlir::IRMapping::map(), partitionOperation(), reshard(), and success().
| void mlir::shard::partitionTriviallyShardableOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder ) |
Definition at line 638 of file ShardingInterface.cpp.
References mlir::OpBuilder::clone(), getGridOrNull(), mlir::Operation::getResults(), and shardType().
Referenced by mlir::shard::ElementwiseShardingInterface< ElemwiseOp >::partition(), and mlir::shard::IndependentParallelIteratorDomainShardingInterface< Op >::partition().
| void mlir::shard::populateAllOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 190 of file Transforms.cpp.
References mlir::patterns, populateAllSliceOpLoweringPatterns(), and populateProcessMultiIndexOpLoweringPatterns().
| void mlir::shard::populateAllReduceEndomorphismSimplificationPatterns | ( | RewritePatternSet & | patterns, |
| ReductionKind | reduction ) |
Definition at line 40 of file Simplifications.h.
References mlir::patterns.
Referenced by populateSimplificationPatterns().
| void mlir::shard::populateAllSliceOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 178 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateFoldingPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 114 of file Simplifications.cpp.
References mlir::patterns.
Referenced by populateSimplificationPatterns().
| void mlir::shard::populateProcessMultiIndexOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 168 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateSimplificationPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 23 of file Simplifications.cpp.
References mlir::patterns, populateAllReduceEndomorphismSimplificationPatterns(), and populateFoldingPatterns().
| void mlir::shard::registerAllOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 196 of file Transforms.cpp.
References registerAllSliceOpLoweringDialects(), and registerProcessMultiIndexOpLoweringDialects().
| void mlir::shard::registerAllSliceOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 184 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
| void mlir::shard::registerProcessMultiIndexOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 174 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
|
inline |
|
inline |
Definition at line 242 of file Passes.h.
Referenced by mlir::registerAllPasses().
| void mlir::shard::removeTrailingEmptySubArray | ( | SmallVector< SmallVector< T > > & | array | ) |
Definition at line 100 of file ShardOps.h.
Referenced by mlir::shard::detail::defaultGetShardingOption(), getSharding(), and getSharding().
|
static |
Definition at line 480 of file Partition.cpp.
References isFullReplication(), reshardOn1DGrid(), and tryUpdateHaloInResharding().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| GridOp | grid, | ||
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue ) |
Definition at line 504 of file Partition.cpp.
References reshard(), and target.
Referenced by partitionOperation(), reshard(), and reshard().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue, | ||
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 515 of file Partition.cpp.
| void mlir::shard::reshardingRegisterDependentDialects | ( | DialectRegistry & | registry | ) |
Definition at line 524 of file Partition.cpp.
References mlir::DialectRegistry::insert().
|
static |
Definition at line 438 of file Partition.cpp.
References mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), shardShapedType(), tryMoveLastSplitAxisInResharding(), trySplitLastAxisInResharding(), and tryUnsplitLastAxisInResharding().
Referenced by reshard().
Definition at line 168 of file ShardOps.h.
Referenced by allToAllResultShapeInMoveLastAxis().
|
static |
Definition at line 537 of file Partition.cpp.
References mlir::Block::getArguments(), getGrid(), mlir::Operation::getUsers(), and shardShapedType().
Referenced by partitionBlock().
| ShapedType mlir::shard::shardShapedType | ( | ShapedType | shape, |
| GridOp | grid, | ||
| Sharding | sharding ) |
Definition at line 281 of file ShardOps.cpp.
References mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), and shardShape().
Referenced by moveLastSplitAxisInResharding(), reshardOn1DGrid(), shardedBlockArgumentTypes(), shardType(), and unsplitLastAxisInResharding().
Definition at line 291 of file ShardOps.cpp.
References shardShapedType().
Referenced by partitionTriviallyShardableOperation().
|
static |
Definition at line 68 of file Partition.cpp.
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 275 of file Partition.cpp.
References mlir::detail::DenseArrayAttrImpl< int16_t >::get(), mlir::shard::Sharding::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 46 of file Partition.cpp.
|
static |
Definition at line 164 of file Partition.cpp.
References mlir::detail::DenseArrayAttrImpl< int16_t >::get(), mlir::shard::Sharding::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by unsplitLastAxisInResharding().
|
static |
Definition at line 344 of file Partition.cpp.
References detectMoveLastSplitAxisInResharding(), and moveLastSplitAxisInResharding().
Referenced by reshardOn1DGrid().
|
static |
Definition at line 119 of file Partition.cpp.
References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().
Referenced by reshardOn1DGrid().
|
static |
Definition at line 213 of file Partition.cpp.
References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().
Referenced by reshardOn1DGrid().
|
static |
Definition at line 365 of file Partition.cpp.
References mlir::shard::Sharding::equalHaloSizes(), mlir::shard::Sharding::equalSplitAxes(), mlir::Builder::getContext(), mlir::shard::Sharding::getDynamicHaloSizes(), mlir::ImplicitLocOpBuilder::getLoc(), mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), and mlir::shard::Sharding::getStaticShardedDimsOffsets().
Referenced by reshard().
|
static |
Definition at line 188 of file Partition.cpp.
References allGatherResultShapeInUnsplitLastAxis(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInUnsplitLastAxis().
Referenced by tryUnsplitLastAxisInResharding().