|
MLIR
22.0.0git
|
Namespaces | |
| detail | |
Classes | |
| struct | ShardingOption |
| struct | IndependentParallelIteratorDomainShardingInterface |
| struct | ElementwiseShardingInterface |
| class | Sharding |
| struct | OpRewritePatternWithSymbolTableCollection |
Typedefs | |
| using | ShardingArray = SmallVector< SmallVector< GridAxis > > |
| using | ShardingArrayRef = ArrayRef< SmallVector< GridAxis > > |
| using | GridAxis = int16_t |
| using | GridAxesAttr = DenseI16ArrayAttr |
| using | ShardShapeAttr = DenseI64ArrayAttr |
| using | HaloSizePairAttr = DenseI64ArrayAttr |
| using | UnshardedToShardedValueMap = DenseMap< Value, Value > |
Enumerations | |
| enum class | TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward } |
| This enum controls the traversal order for the sharding propagation. More... | |
Functions | |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpResult result) |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpOperand &opOperand) |
| void | partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| ShardingArray | getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps) |
| bool | isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators) |
| SmallVector< GridAxis > | getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators) |
| void | partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| bool | isReductionLoop (utils::IteratorType iType) |
| template<typename T > | |
| void | removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array) |
| bool | isFullReplication (Sharding sharding) |
| shard::GridOp | getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| shard::GridOp | getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| template<typename Op > | |
| shard::GridOp | getGrid (Op op, SymbolTableCollection &symbolTableCollection) |
| template<> | |
| shard::GridOp | getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection) |
| template<typename GridAxesRange , typename GridShapeRange > | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape) |
| template<typename GridAxesRange > | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid) |
| int64_t | shardDimension (int64_t dimSize, int64_t shardCount) |
| int64_t | gatherDimension (int64_t dimSize, int64_t shardCount) |
| ShapedType | shardShapedType (ShapedType shape, GridOp grid, Sharding sharding) |
| Type | shardType (Type type, GridOp grid, Sharding sharding) |
| void | maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder) |
| void | maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder) |
| SmallVector< Value > | getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type()) |
| Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type. More... | |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
| void | reshardingRegisterDependentDialects (DialectRegistry &registry) |
| template<typename AlgebraicOp > | |
| void | populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction) |
| void | populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllSliceOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllOpLoweringDialects (DialectRegistry &registry) |
| TypedValue< IndexType > | createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder) |
| TypedValue< IndexType > | createProcessLinearIndex (StringRef grid, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder) |
| TypedValue< IndexType > | createProcessLinearIndex (StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes, ImplicitLocOpBuilder &builder) |
| template<typename SourceAxes , typename TargetAxes > | |
| static bool | arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
| static Sharding | targetShardingInSplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | splitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::optional< std::tuple< int64_t, GridAxis > > | detectSplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, GridAxis > > | detectUnsplitLastAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
| static Sharding | targetShardingInUnsplitLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis) |
| static ShapedType | allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< int64_t, int64_t, GridAxis > > | detectMoveLastSplitAxisInResharding (Sharding sourceSharding, Sharding targetSharding) |
| static Sharding | targetShardingInMoveLastAxis (MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static ShapedType | allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
| static std::tuple< TypedValue< ShapedType >, Sharding > | moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > | tryUpdateHaloInResharding (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
| static TypedValue< ShapedType > | reshardOn1DGrid (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
| static TypedValue< ShapedType > | reshard (ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
| static SmallVector< Type > | shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
| static LogicalResult | partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static std::vector< Sharding > | getOperandShardings (Operation &op) |
| static std::vector< Sharding > | getResultShardings (Operation &op) |
| static LogicalResult | partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection) |
| using mlir::shard::GridAxesAttr = DenseI16ArrayAttr |
Definition at line 27 of file ShardOps.h.
| using mlir::shard::GridAxis = int16_t |
Definition at line 26 of file ShardOps.h.
| using mlir::shard::HaloSizePairAttr = DenseI64ArrayAttr |
Definition at line 29 of file ShardOps.h.
| using mlir::shard::ShardingArray = SmallVector<SmallVector<GridAxis> > |
Definition at line 25 of file ShardingInterface.h.
| using mlir::shard::ShardingArrayRef = ArrayRef<SmallVector<GridAxis> > |
Definition at line 26 of file ShardingInterface.h.
| using mlir::shard::ShardShapeAttr = DenseI64ArrayAttr |
Definition at line 28 of file ShardOps.h.
| using mlir::shard::UnshardedToShardedValueMap = DenseMap<Value, Value> |
Definition at line 531 of file Partition.cpp.
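| enum class mlir::shard::TraversalOrder |
strong |
This enum controls the traversal order for the sharding propagation. Enumerators: Forward, Backward, ForwardBackward, BackwardForward.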
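| ShapedType mlir::shard::allGatherResultShapeInUnsplitLastAxis | ( | ShapedType | sourceShape, |
| int64_t | splitCount, | ||
| int64_t | splitTensorAxis | ||
| ) |
static |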
Definition at line 180 of file Partition.cpp.
References gatherDimension().
Referenced by unsplitLastAxisInResharding().
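| ShapedType mlir::shard::allToAllResultShapeInMoveLastAxis | ( | ShapedType | sourceShape, |
| int64_t | splitCount, | ||
| int64_t | sourceTensorAxis, | ||
| int64_t | targetTensorAxis | ||
| ) |
static |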
Definition at line 303 of file Partition.cpp.
References gatherDimension(), and shardDimension().
Referenced by moveLastSplitAxisInResharding().
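| bool mlir::shard::arePartialAxesCompatible | ( | const SourceAxes & | sourceAxes, |
| const TargetAxes & | targetAxes | ||
| ) |
static |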
Definition at line 39 of file Partition.cpp.
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridOp | grid | ||
| ) |
Definition at line 162 of file ShardOps.h.
References collectiveProcessGroupSize().
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridShapeRange && | gridShape | ||
| ) |
Definition at line 146 of file ShardOps.h.
Referenced by collectiveProcessGroupSize(), sliceResultType(), verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
| TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize | ( | GridOp | grid, |
| ArrayRef< GridAxis > | axes, | ||
| ImplicitLocOpBuilder & | builder | ||
| ) |
Definition at line 202 of file Transforms.cpp.
References mlir::arith::createProduct(), mlir::Builder::getIndexType(), and mlir::ImplicitLocOpBuilder::getLoc().
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | StringRef | grid, |
| ArrayRef< GridAxis > | gridAxes, | ||
| ImplicitLocOpBuilder & | builder | ||
| ) |
Definition at line 228 of file Transforms.cpp.
Referenced by mlir::linalg::createDestinationPassingStyleInitOperand().
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | StringRef | grid, |
| ValueRange | processInGroupMultiIndex, | ||
| ArrayRef< GridAxis > | gridAxes, | ||
| ImplicitLocOpBuilder & | builder | ||
| ) |
Definition at line 212 of file Transforms.cpp.
References mlir::arith::ConstantIndexOp::create(), and mlir::affine::linearizeIndex().
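| std::optional< std::tuple< int64_t, int64_t, GridAxis > > mlir::shard::detectMoveLastSplitAxisInResharding | ( | Sharding | sourceSharding, |
| Sharding | targetSharding | ||
| ) |
static |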
Definition at line 234 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryMoveLastSplitAxisInResharding().
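| std::optional< std::tuple< int64_t, GridAxis > > mlir::shard::detectSplitLastAxisInResharding | ( | Sharding | sourceSharding, |
| Sharding | targetSharding | ||
| ) |
static |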
Definition at line 87 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by trySplitLastAxisInResharding().
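| std::optional< std::tuple< int64_t, GridAxis > > mlir::shard::detectUnsplitLastAxisInResharding | ( | Sharding | sourceSharding, |
| Sharding | targetSharding | ||
| ) |
static |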
Definition at line 136 of file Partition.cpp.
References mlir::shard::Sharding::getSplitAxes().
Referenced by tryUnsplitLastAxisInResharding().
| int64_t mlir::shard::gatherDimension | ( | int64_t | dimSize, |
| int64_t | shardCount | ||
| ) |
inline |
Definition at line 177 of file ShardOps.h.
Referenced by allGatherResultShapeInUnsplitLastAxis(), and allToAllResultShapeInMoveLastAxis().
| shard::GridOp mlir::shard::getGrid | ( | Op | op, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
Definition at line 130 of file ShardOps.h.
References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().
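| shard::GridOp mlir::shard::getGrid | ( | Operation * | op, |
| FlatSymbolRefAttr | gridSymbol, | ||
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
inline |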
Definition at line 121 of file ShardOps.h.
References getGridOrNull().
Referenced by getGrid(), mlir::linalg::getGrid(), getGrid< ShardOp >(), reshard(), and shardedBlockArgumentTypes().
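| shard::GridOp mlir::shard::getGrid< ShardOp > | ( | ShardOp | op, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
inline |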
Definition at line 135 of file ShardOps.h.
References getGrid().
| ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators | ( | ArrayRef< Sharding > | operandShardings, |
| ArrayRef< Sharding > | resultShardings, | ||
| ArrayRef< utils::IteratorType > | loopIteratorTypes, | ||
| ArrayRef< AffineMap > | indexingMaps | ||
| ) |
Definition at line 572 of file ShardingInterface.cpp.
References updateGridAxisAssignmentForLoopIterators().
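| shard::GridOp mlir::shard::getGridOrNull | ( | Operation * | op, |
| FlatSymbolRefAttr | gridSymbol, | ||
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
inline |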
Definition at line 113 of file ShardOps.h.
References mlir::SymbolTableCollection::lookupNearestSymbolFrom().
Referenced by getGrid(), getGridAndVerify(), and partitionTriviallyShardableOperation().
| SmallVector< Value > mlir::shard::getMixedAsValues | ( | OpBuilder | b, |
| const Location & | loc, | ||
| llvm::ArrayRef< int64_t > | statics, | ||
| ValueRange | dynamics, | ||
| Type | type = Type() | ||
| ) |
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
Definition at line 77 of file ShardOps.cpp.
References mlir::Builder::getI64IntegerAttr(), mlir::Builder::getI64Type(), mlir::Builder::getIndexAttr(), and mlir::Builder::getIndexType().
| std::vector< Sharding > mlir::shard::getOperandShardings | ( | Operation & | op | ) |
static |
Definition at line 589 of file Partition.cpp.
References mlir::Value::getDefiningOp(), mlir::Operation::getNumOperands(), and mlir::Operation::getOperands().
Referenced by partitionOperation().
| SmallVector< GridAxis > mlir::shard::getReductionGridAxes | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis >> | gridAxisAssignmentForLoopIterators | ||
| ) |
Definition at line 625 of file ShardingInterface.cpp.
Referenced by mlir::linalg::partitionLinalgOpWithShardedReduction().
| std::vector< Sharding > mlir::shard::getResultShardings | ( | Operation & | op | ) |
static |
Definition at line 609 of file Partition.cpp.
References mlir::Operation::getNumResults(), mlir::Operation::getOperands(), mlir::Operation::getResults(), and mlir::Value::getUsers().
Referenced by partitionOperation().
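| FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding | ( | OpOperand & | opOperand | ) |
Definition at line 152 of file ShardingInterface.cpp.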
References mlir::IROperand< DerivedT, IRValueT >::get(), and mlir::Value::getDefiningOp().
| FailureOr< std::pair< bool, Sharding > > mlir::shard::getSharding | ( | OpResult | result | ) |
Definition at line 109 of file ShardingInterface.cpp.
References mlir::Value::getUsers(), and mlir::Value::hasOneUse().
Referenced by addShardOp(), mlir::shard::detail::defaultGetShardingAnnotations(), and visitOp().
| bool mlir::shard::isAtLeastOneReductionIteratorSharded | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis >> | gridAxisAssignmentForLoopIterators | ||
| ) |
Definition at line 612 of file ShardingInterface.cpp.
| bool mlir::shard::isFullReplication | ( | Sharding | sharding | ) |
inline |
Definition at line 106 of file ShardOps.h.
References mlir::shard::Sharding::getSplitAxes().
Referenced by isValueCompatibleWithFullReplicationSharding(), maybeInsertSourceShardingAnnotation(), and reshard().
| bool mlir::shard::isReductionLoop | ( | utils::IteratorType | iType | ) |
inline |
Definition at line 94 of file ShardOps.h.
| void mlir::shard::maybeInsertSourceShardingAnnotation | ( | Sharding | sharding, |
| OpOperand & | operand, | ||
| OpBuilder & | builder | ||
| ) |
Definition at line 352 of file ShardOps.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getDefiningOp(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), mlir::Value::getType(), mlir::Operation::hasTrait(), isFullReplication(), mlir::RewriterBase::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by addShardOp().
| void mlir::shard::maybeInsertTargetShardingAnnotation | ( | Sharding | sharding, |
| OpResult | result, | ||
| OpBuilder & | builder | ||
| ) |
Definition at line 338 of file ShardOps.cpp.
References mlir::Value::getUses(), and maybeInsertTargetShardingAnnotationImpl().
Referenced by addShardOp().
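| std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::moveLastSplitAxisInResharding | ( | ImplicitLocOpBuilder & | builder, |
| GridOp | grid, | ||
| Sharding | sourceSharding, | ||
| ShapedType | sourceUnshardedShape, | ||
| TypedValue< ShapedType > | sourceShard, | ||
| int64_t | sourceTensorAxis, | ||
| int64_t | targetTensorAxis, | ||
| GridAxis | gridAxis | ||
| ) |
static |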
Definition at line 316 of file Partition.cpp.
References allToAllResultShapeInMoveLastAxis(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInMoveLastAxis().
Referenced by tryMoveLastSplitAxisInResharding().
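| LogicalResult mlir::shard::partitionBlock | ( | Block & | block, |
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTableCollection, | ||
| OpBuilder & | builder | ||
| ) |
static |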
Definition at line 701 of file Partition.cpp.
References mlir::OpBuilder::createBlock(), mlir::remark::failed(), mlir::Block::getArguments(), mlir::Block::getOperations(), mlir::Block::getParent(), mlir::IRMapping::map(), partitionOperation(), mlir::OpBuilder::setInsertionPointToEnd(), and shardedBlockArgumentTypes().
Referenced by partitionFuncOp().
| void mlir::shard::partitionFullyReplicatedOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder | ||
| ) |
Definition at line 543 of file ShardingInterface.cpp.
References areValuesCompatibleWithFullReplicationShardings(), mlir::OpBuilder::clone(), mlir::Operation::getOperands(), and mlir::Operation::getResults().
Referenced by partitionOperation().
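| LogicalResult mlir::shard::partitionFuncOp | ( | FunctionOpInterface | op, |
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
static |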
Definition at line 729 of file Partition.cpp.
References mlir::remark::failed(), mlir::get(), mlir::Operation::getOperandTypes(), and partitionBlock().
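| LogicalResult mlir::shard::partitionOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTableCollection, | ||
| OpBuilder & | builder | ||
| ) |
static |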
Definition at line 560 of file Partition.cpp.
References mlir::remark::failed(), mlir::Operation::getResults(), and partitionFullyReplicatedOperation().
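| LogicalResult mlir::shard::partitionOperation | ( | Operation & | op, |
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTableCollection, | ||
| OpBuilder & | builder | ||
| ) |
static |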
Definition at line 667 of file Partition.cpp.
References mlir::OpBuilder::clone(), mlir::Operation::emitError(), mlir::Operation::getOperands(), getOperandShardings(), mlir::Operation::getResult(), getResultShardings(), and mlir::IRMapping::map().
Referenced by partitionBlock().
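| LogicalResult mlir::shard::partitionOperation | ( | ShardOp | shardOp, |
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTableCollection, | ||
| OpBuilder & | builder | ||
| ) |
static |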
Definition at line 643 of file Partition.cpp.
References mlir::IRMapping::contains(), mlir::Value::getDefiningOp(), mlir::IRMapping::lookup(), mlir::IRMapping::map(), and reshard().
| void mlir::shard::partitionTriviallyShardableOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder | ||
| ) |
Definition at line 638 of file ShardingInterface.cpp.
References mlir::OpBuilder::clone(), getGridOrNull(), mlir::Operation::getResults(), and shardType().
Referenced by mlir::shard::IndependentParallelIteratorDomainShardingInterface< Op >::partition(), mlir::shard::ElementwiseShardingInterface< ElemwiseOp >::partition(), and mlir::linalg::partitionLinalgOpWithShardedReduction().
| void mlir::shard::populateAllOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
Definition at line 190 of file Transforms.cpp.
References mlir::patterns, populateAllSliceOpLoweringPatterns(), and populateProcessMultiIndexOpLoweringPatterns().
| void mlir::shard::populateAllReduceEndomorphismSimplificationPatterns | ( | RewritePatternSet & | patterns, |
| ReductionKind | reduction | ||
| ) |
Definition at line 40 of file Simplifications.h.
References mlir::patterns.
| void mlir::shard::populateAllSliceOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
Definition at line 178 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateFoldingPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
Definition at line 114 of file Simplifications.cpp.
References mlir::patterns.
Referenced by populateSimplificationPatterns().
| void mlir::shard::populateProcessMultiIndexOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
Definition at line 168 of file Transforms.cpp.
References mlir::patterns.
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateSimplificationPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
Definition at line 23 of file Simplifications.cpp.
References mlir::patterns, and populateFoldingPatterns().
| void mlir::shard::registerAllOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 196 of file Transforms.cpp.
References registerAllSliceOpLoweringDialects(), and registerProcessMultiIndexOpLoweringDialects().
| void mlir::shard::registerAllSliceOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 184 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
| void mlir::shard::registerProcessMultiIndexOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 174 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
| void mlir::shard::removeTrailingEmptySubArray | ( | SmallVector< SmallVector< T >> & | array | ) |
Definition at line 100 of file ShardOps.h.
Referenced by mlir::shard::detail::defaultGetShardingOption(), and getSharding().
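| TypedValue< ShapedType > mlir::shard::reshard | ( | ImplicitLocOpBuilder & | builder, |
| GridOp | grid, | ||
| Sharding | sourceSharding, | ||
| Sharding | targetSharding, | ||
| TypedValue< ShapedType > | sourceUnshardedValue, | ||
| TypedValue< ShapedType > | sourceShard | ||
| ) |
static |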
Definition at line 480 of file Partition.cpp.
References isFullReplication(), reshardOn1DGrid(), and tryUpdateHaloInResharding().
Referenced by partitionOperation().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| GridOp | grid, | ||
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue | ||
| ) |
Definition at line 504 of file Partition.cpp.
Referenced by reshard().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue, | ||
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
Definition at line 515 of file Partition.cpp.
| void mlir::shard::reshardingRegisterDependentDialects | ( | DialectRegistry & | registry | ) |
Definition at line 524 of file Partition.cpp.
References mlir::DialectRegistry::insert().
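| TypedValue< ShapedType > mlir::shard::reshardOn1DGrid | ( | ImplicitLocOpBuilder & | builder, |
| GridOp | grid, | ||
| Sharding | sourceSharding, | ||
| Sharding | targetSharding, | ||
| TypedValue< ShapedType > | sourceUnshardedValue, | ||
| TypedValue< ShapedType > | sourceShard | ||
| ) |
static |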
Definition at line 438 of file Partition.cpp.
References mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), shardShapedType(), tryMoveLastSplitAxisInResharding(), trySplitLastAxisInResharding(), and tryUnsplitLastAxisInResharding().
Referenced by reshard().
| int64_t mlir::shard::shardDimension | ( | int64_t | dimSize, |
| int64_t | shardCount | ||
| ) |
inline |
Definition at line 168 of file ShardOps.h.
Referenced by allToAllResultShapeInMoveLastAxis().
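| SmallVector< Type > mlir::shard::shardedBlockArgumentTypes | ( | Block & | block, |
| SymbolTableCollection & | symbolTableCollection | ||
| ) |
static |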
Definition at line 537 of file Partition.cpp.
References mlir::Block::getArguments(), getGrid(), mlir::Operation::getUsers(), and shardShapedType().
Referenced by partitionBlock().
| ShapedType mlir::shard::shardShapedType | ( | ShapedType | shape, |
| GridOp | grid, | ||
| Sharding | sharding | ||
| ) |
Definition at line 281 of file ShardOps.cpp.
References mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), and shardShape().
Referenced by moveLastSplitAxisInResharding(), reshardOn1DGrid(), shardedBlockArgumentTypes(), shardType(), and unsplitLastAxisInResharding().
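| Type mlir::shard::shardType | ( | Type | type, |
| GridOp | grid, | ||
| Sharding | sharding | ||
| ) |
Definition at line 291 of file ShardOps.cpp.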
References shardShapedType().
Referenced by partitionTriviallyShardableOperation().
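| std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::splitLastAxisInResharding | ( | ImplicitLocOpBuilder & | builder, |
| Sharding | sourceSharding, | ||
| TypedValue< ShapedType > | sourceShard, | ||
| GridOp | grid, | ||
| int64_t | splitTensorAxis, | ||
| GridAxis | splitGridAxis | ||
| ) |
static |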
Definition at line 68 of file Partition.cpp.
References mlir::Builder::getContext(), and targetShardingInSplitLastAxis().
Referenced by trySplitLastAxisInResharding().
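| Sharding mlir::shard::targetShardingInMoveLastAxis | ( | MLIRContext * | ctx, |
| Sharding | sourceSharding, | ||
| int64_t | sourceTensorAxis, | ||
| int64_t | targetTensorAxis | ||
| ) |
static |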
Definition at line 275 of file Partition.cpp.
References mlir::shard::Sharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by moveLastSplitAxisInResharding().
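| Sharding mlir::shard::targetShardingInSplitLastAxis | ( | MLIRContext * | ctx, |
| Sharding | sourceSharding, | ||
| int64_t | splitTensorAxis, | ||
| GridAxis | splitGridAxis | ||
| ) |
static |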
Definition at line 46 of file Partition.cpp.
References mlir::shard::Sharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by splitLastAxisInResharding().
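| Sharding mlir::shard::targetShardingInUnsplitLastAxis | ( | MLIRContext * | ctx, |
| Sharding | sourceSharding, | ||
| int64_t | splitTensorAxis | ||
| ) |
static |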
Definition at line 164 of file Partition.cpp.
References mlir::shard::Sharding::get(), mlir::detail::DenseArrayAttrImpl< T >::get(), mlir::shard::Sharding::getGridAttr(), and mlir::shard::Sharding::getSplitAxes().
Referenced by unsplitLastAxisInResharding().
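| std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryMoveLastSplitAxisInResharding | ( | ImplicitLocOpBuilder & | builder, |
| GridOp | grid, | ||
| Sharding | sourceSharding, | ||
| Sharding | targetSharding, | ||
| ShapedType | sourceUnshardedShape, | ||
| TypedValue< ShapedType > | sourceShard | ||
| ) |
static |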
Definition at line 344 of file Partition.cpp.
References detectMoveLastSplitAxisInResharding(), and moveLastSplitAxisInResharding().
Referenced by reshardOn1DGrid().
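| std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::trySplitLastAxisInResharding | ( | ImplicitLocOpBuilder & | builder, |
| GridOp | grid, | ||
| Sharding | sourceSharding, | ||
| Sharding | targetSharding, | ||
| TypedValue< ShapedType > | sourceShard | ||
| ) |
static |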
Definition at line 119 of file Partition.cpp.
References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().
Referenced by reshardOn1DGrid().
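| std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryUnsplitLastAxisInResharding | ( | ImplicitLocOpBuilder & | builder, |
| GridOp | grid, | ||
| Sharding | sourceSharding, | ||
| Sharding | targetSharding, | ||
| ShapedType | sourceUnshardedShape, | ||
| TypedValue< ShapedType > | sourceShard | ||
| ) |
static |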
Definition at line 213 of file Partition.cpp.
References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().
Referenced by reshardOn1DGrid().
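| std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > mlir::shard::tryUpdateHaloInResharding | ( | ImplicitLocOpBuilder & | builder, |
| GridOp | grid, | ||
| Sharding | sourceSharding, | ||
| Sharding | targetSharding, | ||
| ShapedType | sourceUnshardedShape, | ||
| TypedValue< ShapedType > | sourceShard | ||
| ) |
static |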
Definition at line 365 of file Partition.cpp.
References mlir::shard::Sharding::equalHaloSizes(), mlir::shard::Sharding::equalSplitAxes(), mlir::get(), mlir::Builder::getContext(), mlir::shard::Sharding::getDynamicHaloSizes(), mlir::ImplicitLocOpBuilder::getLoc(), mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), and mlir::shard::Sharding::getStaticShardedDimsOffsets().
Referenced by reshard().
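| std::tuple< TypedValue< ShapedType >, Sharding > mlir::shard::unsplitLastAxisInResharding | ( | ImplicitLocOpBuilder & | builder, |
| Sharding | sourceSharding, | ||
| ShapedType | sourceUnshardedShape, | ||
| TypedValue< ShapedType > | sourceShard, | ||
| GridOp | grid, | ||
| int64_t | splitTensorAxis, | ||
| GridAxis | splitGridAxis | ||
| ) |
static |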
Definition at line 188 of file Partition.cpp.
References allGatherResultShapeInUnsplitLastAxis(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInUnsplitLastAxis().
Referenced by tryUnsplitLastAxisInResharding().