|
MLIR 23.0.0git
|
Namespaces | |
| namespace | detail |
| namespace | impl |
Classes | |
| struct | ElementwiseShardingInterface |
| struct | IndependentParallelIteratorDomainShardingInterface |
| struct | OpRewritePatternWithSymbolTableCollection |
| class | Sharding |
| struct | ShardingOption |
| struct | ShardingPropagationOptions |
Typedefs | |
| using | ShardingArray = SmallVector<SmallVector<GridAxis>> |
| using | ShardingArrayRef = ArrayRef<SmallVector<GridAxis>> |
| using | GridAxis = int16_t |
| using | GridAxesAttr = DenseI16ArrayAttr |
| using | ShardShapeAttr = DenseI64ArrayAttr |
| using | HaloSizePairAttr = DenseI64ArrayAttr |
| using | UnshardedToShardedValueMap = DenseMap<Value, Value> |
Enumerations | |
| enum class | TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward } |
| This enum controls the traversal order for the sharding propagation. More... | |
Functions | |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpResult result) |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpOperand &opOperand) |
| void | partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| ShardingArray | getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps) |
| bool | isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators) |
| SmallVector< GridAxis > | getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators) |
| void | partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| bool | isReductionLoop (utils::IteratorType iType) |
| template<typename T> | |
| void | removeTrailingEmptySubArray (SmallVector< SmallVector< T > > &array) |
| bool | isFullReplication (Sharding sharding) |
| shard::GridOp | getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| shard::GridOp | getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| template<typename Op> | |
| shard::GridOp | getGrid (Op op, SymbolTableCollection &symbolTableCollection) |
| template<> | |
| shard::GridOp | getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection) |
| template<typename GridAxesRange, typename GridShapeRange> | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape) |
| template<typename GridAxesRange> | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid) |
| int64_t | shardDimension (int64_t dimSize, int64_t shardCount) |
| int64_t | gatherDimension (int64_t dimSize, int64_t shardCount) |
| ShapedType | shardShapedType (ShapedType shape, GridOp grid, Sharding sharding) |
| Type | shardType (Type type, GridOp grid, Sharding sharding) |
| void | maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder) |
| void | maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder) |
| SmallVector< Value > | getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type()) |
| Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type. | |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
| void | reshardingRegisterDependentDialects (DialectRegistry &registry) |
| std::unique_ptr<::mlir::Pass > | createPartition () |
| std::unique_ptr<::mlir::Pass > | createShardingPropagation () |
| std::unique_ptr<::mlir::Pass > | createShardingPropagation (ShardingPropagationOptions options) |
| void | registerPartition () |
| void | registerPartitionPass () |
| void | registerShardingPropagation () |
| void | registerShardingPropagationPass () |
| void | registerShardPasses () |
| template<typename AlgebraicOp> | |
| void | populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction) |
| void | populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllSliceOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllOpLoweringDialects (DialectRegistry &registry) |
| TypedValue< IndexType > | createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder) |
| TypedValue< IndexType > | createProcessLinearIndex (ImplicitLocOpBuilder &builder, StringRef grid, ArrayRef< GridAxis > gridAxes={}) |
| TypedValue< IndexType > | createProcessLinearIndex (ImplicitLocOpBuilder &builder, StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes={}) |
| template<typename SourceAxes, typename TargetAxes> | |
| static bool | arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
| static Sharding | targetShardingInSplitLastAxis (MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::optional< std::tuple< int64_t, GridAxis > > | detectSplitLastAxisInResharding (const Sharding &sourceSharding, const Sharding &targetSharding) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, SmallVector< GridAxis > > > | detectUnsplitLastAxesInResharding (const Sharding &srcSharding, const Sharding &tgtSharding) |
| static Sharding | targetShardingInUnsplitLastAxes (MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorDim, size_t numUnsplitAxes) |
| static ShapedType | allGatherResultTypeInUnsplitLastAxes (ShapedType sourceType, int64_t splitTensorDim, ArrayRef< int64_t > gridShape, ArrayRef< GridAxis > unsplitAxes) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | unsplitLastAxesInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorDim, ArrayRef< GridAxis > unsplitAxes) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUnsplitLastAxesInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, int64_t, GridAxis > > | detectMoveLastSplitAxisInResharding (const Sharding &sourceSharding, const Sharding &targetSharding) |
| static Sharding | targetShardingInMoveLastAxis (MLIRContext *ctx, const Sharding &sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static ShapedType | allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, const Sharding &targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static TypedValue< ShapedType > | reshard (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, const Sharding &targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
| static SmallVector< Type > | shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
| static LogicalResult | partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static std::vector< Sharding > | getOperandShardings (Operation &op) |
| static std::vector< Sharding > | getResultShardings (Operation &op) |
| static LogicalResult | partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | checkFullyAnnotated (Block &block) |
| static LogicalResult | checkFullyAnnotated (Operation *op) |
| static LogicalResult | partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection) |
Definition at line 27 of file ShardOps.h.
| using mlir::shard::GridAxis = int16_t |
Definition at line 26 of file ShardOps.h.
Definition at line 29 of file ShardOps.h.
Definition at line 25 of file ShardingInterface.h.
Definition at line 26 of file ShardingInterface.h.
Definition at line 28 of file ShardOps.h.
Definition at line 534 of file Partition.cpp.
|
strong |
|
static |
Definition at line 194 of file Partition.cpp.
References gatherDimension().
Referenced by unsplitLastAxesInResharding().
|
static |
Definition at line 321 of file Partition.cpp.
References gatherDimension(), and shardDimension().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 40 of file Partition.cpp.
|
static |
Definition at line 673 of file Partition.cpp.
References checkFullyAnnotated(), mlir::Block::computeBlockNumber(), mlir::emitError(), mlir::Block::getArguments(), mlir::Region::getLoc(), mlir::Block::getParent(), mlir::Operation::getUsers(), and success().
Referenced by checkFullyAnnotated(), checkFullyAnnotated(), partitionBlock(), and partitionOperation().
|
static |
Definition at line 703 of file Partition.cpp.
References checkFullyAnnotated(), mlir::Operation::emitError(), mlir::Operation::getOpOperands(), mlir::Operation::getResults(), mlir::Operation::hasTrait(), result, and success().
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridOp | grid ) |
Definition at line 162 of file ShardOps.h.
References collectiveProcessGroupSize().
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridShapeRange && | gridShape ) |
Definition at line 146 of file ShardOps.h.
Referenced by collectiveProcessGroupSize(), sliceResultType(), verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
| TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize | ( | GridOp | grid, |
| ArrayRef< GridAxis > | axes, | ||
| ImplicitLocOpBuilder & | builder ) |
Definition at line 201 of file Transforms.cpp.
References mlir::arith::createProduct(), mlir::Builder::getIndexType(), and mlir::ImplicitLocOpBuilder::getLoc().
| std::unique_ptr<::mlir::Pass > mlir::shard::createPartition | ( | ) |
We declare an explicit private instantiation because Pass classes should only be visible to the current library.
Definition at line 80 of file Partition.cpp.
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | ImplicitLocOpBuilder & | builder, |
| StringRef | grid, | ||
| ArrayRef< GridAxis > | gridAxes = {} ) |
Definition at line 227 of file Transforms.cpp.
References createProcessLinearIndex().
Referenced by mlir::linalg::createDestinationPassingStyleInitOperand(), and createProcessLinearIndex().
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | ImplicitLocOpBuilder & | builder, |
| StringRef | grid, | ||
| ValueRange | processInGroupMultiIndex, | ||
| ArrayRef< GridAxis > | gridAxes = {} ) |
Definition at line 211 of file Transforms.cpp.
References mlir::arith::ConstantIndexOp::create().
| std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation | ( | ) |
Definition at line 180 of file ShardingPropagation.cpp.
| std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation | ( | ShardingPropagationOptions | options | ) |
Definition at line 184 of file ShardingPropagation.cpp.
References NO_RESHARDING_FOR_EXPLICIT_ANNOTATIONS, RESHARDING_FOR_EXPLICIT_ANNOTATIONS, and result.
|
static |
Definition at line 252 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryMoveLastSplitAxisInResharding().
|
static |
Definition at line 89 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 141 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryUnsplitLastAxesInResharding().
Definition at line 177 of file ShardOps.h.
Referenced by allGatherResultTypeInUnsplitLastAxes(), and allToAllResultShapeInMoveLastAxis().
| shard::GridOp mlir::shard::getGrid | ( | Op | op, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 130 of file ShardOps.h.
References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().
|
inline |
Definition at line 121 of file ShardOps.h.
References getGridOrNull().
Referenced by mlir::linalg::getGrid(), getGrid(), getGrid< ShardOp >(), reshard(), and shardedBlockArgumentTypes().
|
inline |
Definition at line 135 of file ShardOps.h.
References getGrid().
| ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators | ( | ArrayRef< Sharding > | operandShardings, |
| ArrayRef< Sharding > | resultShardings, | ||
| ArrayRef< utils::IteratorType > | loopIteratorTypes, | ||
| ArrayRef< AffineMap > | indexingMaps ) |
Definition at line 572 of file ShardingInterface.cpp.
References updateGridAxisAssignmentForLoopIterators().
|
inline |
Definition at line 113 of file ShardOps.h.
References mlir::SymbolTableCollection::lookupNearestSymbolFrom().
Referenced by getGrid(), getGridAndVerify(), and partitionTriviallyShardableOperation().
| SmallVector< Value > mlir::shard::getMixedAsValues | ( | OpBuilder | b, |
| const Location & | loc, | ||
| llvm::ArrayRef< int64_t > | statics, | ||
| ValueRange | dynamics, | ||
| Type | type = Type() ) |
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
Definition at line 77 of file ShardOps.cpp.
References b.
Definition at line 592 of file Partition.cpp.
References mlir::Value::getDefiningOp(), mlir::Operation::getNumOperands(), mlir::Operation::getOperands(), and getOperandShardings().
Referenced by getOperandShardings(), and partitionOperation().
| SmallVector< GridAxis > mlir::shard::getReductionGridAxes | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis > > | gridAxisAssignmentForLoopIterators ) |
Definition at line 625 of file ShardingInterface.cpp.
Referenced by mlir::linalg::partitionLinalgOpWithShardedReduction().
Definition at line 612 of file Partition.cpp.
References mlir::Operation::getNumResults(), mlir::Operation::getOperands(), mlir::Operation::getResults(), getResultShardings(), if(), and result.
Referenced by getResultShardings(), and partitionOperation().
Definition at line 152 of file ShardingInterface.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), and mlir::Value::getDefiningOp().
Definition at line 109 of file ShardingInterface.cpp.
References getSharding(), mlir::Value::getUsers(), mlir::Value::hasOneUse(), and result.
Referenced by addShardOp(), addShardOp(), mlir::shard::detail::defaultGetShardingAnnotations(), getSharding(), and visitOp().
| bool mlir::shard::isAtLeastOneReductionIteratorSharded | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis > > | gridAxisAssignmentForLoopIterators ) |
Definition at line 612 of file ShardingInterface.cpp.
Definition at line 106 of file ShardOps.h.
References mlir::shard::Sharding::getSplitAxes().
Referenced by isValueCompatibleWithFullReplicationSharding(), maybeInsertSourceShardingAnnotation(), and reshard().
|
inline |
Definition at line 94 of file ShardOps.h.
| void mlir::shard::maybeInsertSourceShardingAnnotation | ( | Sharding | sharding, |
| OpOperand & | operand, | ||
| OpBuilder & | builder ) |
Definition at line 352 of file ShardOps.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getDefiningOp(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), mlir::Value::getType(), mlir::Operation::hasTrait(), isFullReplication(), mlir::RewriterBase::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by addShardOp().
| void mlir::shard::maybeInsertTargetShardingAnnotation | ( | Sharding | sharding, |
| OpResult | result, | ||
| OpBuilder & | builder ) |
Definition at line 338 of file ShardOps.cpp.
References maybeInsertTargetShardingAnnotationImpl(), and result.
Referenced by addShardOp().
|
static |
Definition at line 334 of file Partition.cpp.
References allToAllResultShapeInMoveLastAxis(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInMoveLastAxis().
Referenced by tryMoveLastSplitAxisInResharding().
|
static |
Definition at line 782 of file Partition.cpp.
References checkFullyAnnotated(), mlir::OpBuilder::createBlock(), mlir::Block::getArguments(), mlir::Block::getOperations(), mlir::Block::getParent(), mlir::IRMapping::map(), partitionBlock(), partitionOperation(), mlir::OpBuilder::setInsertionPointToEnd(), shardedBlockArgumentTypes(), and success().
Referenced by partitionBlock(), and partitionFuncOp().
| void mlir::shard::partitionFullyReplicatedOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder ) |
Definition at line 543 of file ShardingInterface.cpp.
References areValuesCompatibleWithFullReplicationShardings(), mlir::OpBuilder::clone(), mlir::Operation::getOperands(), and mlir::Operation::getResults().
Referenced by partitionOperation().
|
static |
Definition at line 813 of file Partition.cpp.
References b, mlir::Operation::getOperandTypes(), partitionBlock(), partitionFuncOp(), and success().
Referenced by partitionFuncOp().
|
static |
Definition at line 563 of file Partition.cpp.
References mlir::Operation::getResults(), partitionFullyReplicatedOperation(), partitionOperation(), result, and success().
Referenced by partitionBlock(), partitionOperation(), partitionOperation(), and partitionOperation().
|
static |
Definition at line 743 of file Partition.cpp.
References checkFullyAnnotated(), mlir::OpBuilder::clone(), mlir::Operation::emitError(), mlir::Operation::getOperands(), getOperandShardings(), mlir::Operation::getResult(), getResultShardings(), mlir::IRMapping::map(), partitionOperation(), and success().
|
static |
Definition at line 646 of file Partition.cpp.
References mlir::IRMapping::contains(), mlir::Value::getDefiningOp(), mlir::IRMapping::lookup(), mlir::IRMapping::map(), partitionOperation(), reshard(), and success().
| void mlir::shard::partitionTriviallyShardableOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder ) |
Definition at line 638 of file ShardingInterface.cpp.
References mlir::OpBuilder::clone(), getGridOrNull(), mlir::Operation::getResults(), and shardType().
Referenced by mlir::shard::ElementwiseShardingInterface< ElemwiseOp >::partition(), and mlir::shard::IndependentParallelIteratorDomainShardingInterface< Op >::partition().
| void mlir::shard::populateAllOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 189 of file Transforms.cpp.
References mlir::patterns, populateAllSliceOpLoweringPatterns(), and populateProcessMultiIndexOpLoweringPatterns().
| void mlir::shard::populateAllReduceEndomorphismSimplificationPatterns | ( | RewritePatternSet & | patterns, |
| ReductionKind | reduction ) |
Definition at line 40 of file Simplifications.h.
References mlir::patterns.
Referenced by populateSimplificationPatterns().
| void mlir::shard::populateAllSliceOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 177 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateFoldingPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 114 of file Simplifications.cpp.
References mlir::patterns.
Referenced by populateSimplificationPatterns().
| void mlir::shard::populateProcessMultiIndexOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 167 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateSimplificationPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 23 of file Simplifications.cpp.
References mlir::patterns, populateAllReduceEndomorphismSimplificationPatterns(), and populateFoldingPatterns().
| void mlir::shard::registerAllOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 195 of file Transforms.cpp.
References registerAllSliceOpLoweringDialects(), and registerProcessMultiIndexOpLoweringDialects().
| void mlir::shard::registerAllSliceOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 183 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
| void mlir::shard::registerProcessMultiIndexOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 173 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
|
inline |
|
inline |
Definition at line 242 of file Passes.h.
Referenced by mlir::registerAllPasses().
| void mlir::shard::removeTrailingEmptySubArray | ( | SmallVector< SmallVector< T > > & | array | ) |
Definition at line 100 of file ShardOps.h.
Referenced by mlir::shard::detail::defaultGetShardingOption(), getSharding(), and getSharding().
|
static |
Definition at line 456 of file Partition.cpp.
References mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), isFullReplication(), shardShapedType(), tryMoveLastSplitAxisInResharding(), trySplitLastAxisInResharding(), tryUnsplitLastAxesInResharding(), and tryUpdateHaloInResharding().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| GridOp | grid, | ||
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue ) |
Definition at line 507 of file Partition.cpp.
References reshard(), and target.
Referenced by partitionOperation(), reshard(), and reshard().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue, | ||
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 518 of file Partition.cpp.
| void mlir::shard::reshardingRegisterDependentDialects | ( | DialectRegistry & | registry | ) |
Definition at line 527 of file Partition.cpp.
References mlir::DialectRegistry::insert().
Definition at line 168 of file ShardOps.h.
Referenced by allToAllResultShapeInMoveLastAxis().
|
static |
Definition at line 540 of file Partition.cpp.
References mlir::Block::getArguments(), getGrid(), mlir::Operation::getUsers(), and shardShapedType().
Referenced by partitionBlock().
| ShapedType mlir::shard::shardShapedType | ( | ShapedType | shape, |
| GridOp | grid, | ||
| Sharding | sharding ) |
Definition at line 281 of file ShardOps.cpp.
References mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), and shardShape().
Referenced by moveLastSplitAxisInResharding(), reshard(), shardedBlockArgumentTypes(), shardType(), and unsplitLastAxesInResharding().
Definition at line 291 of file ShardOps.cpp.
References shardShapedType().
Referenced by partitionTriviallyShardableOperation().
|
static |
Definition at line 69 of file Partition.cpp.
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 293 of file Partition.cpp.
References mlir::detail::DenseArrayAttrImpl< int16_t >::get(), mlir::shard::Sharding::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 47 of file Partition.cpp.
|
static |
Definition at line 176 of file Partition.cpp.
References mlir::detail::DenseArrayAttrImpl< int16_t >::get(), mlir::shard::Sharding::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by unsplitLastAxesInResharding().
|
static |
Definition at line 362 of file Partition.cpp.
References detectMoveLastSplitAxisInResharding(), and moveLastSplitAxisInResharding().
Referenced by reshard().
|
static |
Definition at line 121 of file Partition.cpp.
References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().
Referenced by reshard().
|
static |
Definition at line 230 of file Partition.cpp.
References detectUnsplitLastAxesInResharding(), and unsplitLastAxesInResharding().
Referenced by reshard().
|
static |
Definition at line 383 of file Partition.cpp.
References mlir::shard::Sharding::equalHaloSizes(), mlir::shard::Sharding::equalSplitAxes(), mlir::Builder::getContext(), mlir::shard::Sharding::getDynamicHaloSizes(), mlir::ImplicitLocOpBuilder::getLoc(), mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), and mlir::shard::Sharding::getStaticShardedDimsOffsets().
Referenced by reshard().
|
static |
Definition at line 206 of file Partition.cpp.
References allGatherResultTypeInUnsplitLastAxes(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInUnsplitLastAxes().
Referenced by tryUnsplitLastAxesInResharding().