|
MLIR 23.0.0git
|
Namespaces | |
| namespace | detail |
| namespace | impl |
Classes | |
| struct | ElementwiseShardingInterface |
| struct | IndependentParallelIteratorDomainShardingInterface |
| class | MoveSplitAxisPattern |
| Move a split axis between tensor dimensions: e.g. More... | |
| struct | OpRewritePatternWithSymbolTableCollection |
| class | ReshardingPattern |
| Base class for resharding patterns. More... | |
| class | Sharding |
| struct | ShardingOption |
| struct | ShardingPropagationOptions |
| class | SplitLastAxisPattern |
| Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]]. More... | |
| class | UnsplitLastAxesPattern |
| Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> []. More... | |
| class | UpdateHaloPattern |
| Update halo sizes: handles cases where only the halo sizes differ between source and target sharding. More... | |
Typedefs | |
| using | ShardingArray = SmallVector<SmallVector<GridAxis>> |
| using | ShardingArrayRef = ArrayRef<SmallVector<GridAxis>> |
| using | GridAxis = int16_t |
| using | GridAxesAttr = DenseI16ArrayAttr |
| using | ShardShapeAttr = DenseI64ArrayAttr |
| using | HaloSizePairAttr = DenseI64ArrayAttr |
| using | UnshardedToShardedValueMap = DenseMap<Value, Value> |
Enumerations | |
| enum class | TraversalOrder { Forward , Backward , ForwardBackward , BackwardForward } |
| This enum controls the traversal order for the sharding propagation. More... | |
Functions | |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpResult result) |
| FailureOr< std::pair< bool, Sharding > > | getSharding (OpOperand &opOperand) |
| void | partitionFullyReplicatedOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| ShardingArray | getGridAxisAssignmentForLoopIterators (ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps) |
| bool | isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators) |
| SmallVector< GridAxis > | getReductionGridAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators) |
| void | partitionTriviallyShardableOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
| llvm::raw_ostream & | operator<< (llvm::raw_ostream &os, const Sharding &sharding) |
| Diagnostic & | operator<< (Diagnostic &diag, const Sharding &sharding) |
| bool | isReductionLoop (utils::IteratorType iType) |
| template<typename T> | |
| void | removeTrailingEmptySubArray (SmallVector< SmallVector< T > > &array) |
| bool | isFullReplication (Sharding sharding) |
| shard::GridOp | getGridOrNull (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| shard::GridOp | getGrid (Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection) |
| template<typename Op> | |
| shard::GridOp | getGrid (Op op, SymbolTableCollection &symbolTableCollection) |
| template<> | |
| shard::GridOp | getGrid< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection) |
| template<typename GridAxesRange, typename GridShapeRange> | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridShapeRange &&gridShape) |
| template<typename GridAxesRange> | |
| int64_t | collectiveProcessGroupSize (GridAxesRange &&gridAxes, GridOp grid) |
| int64_t | shardDimension (int64_t dimSize, int64_t shardCount) |
| int64_t | gatherDimension (int64_t dimSize, int64_t shardCount) |
| ShapedType | shardShapedType (ShapedType shape, GridOp grid, Sharding sharding) |
| Type | shardType (Type type, GridOp grid, Sharding sharding) |
| void | maybeInsertTargetShardingAnnotation (Sharding sharding, OpResult result, OpBuilder &builder) |
| void | maybeInsertSourceShardingAnnotation (Sharding sharding, OpOperand &operand, OpBuilder &builder) |
| SmallVector< Value > | getMixedAsValues (OpBuilder b, const Location &loc, llvm::ArrayRef< int64_t > statics, ValueRange dynamics, Type type=Type()) |
| Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type. | |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
| TypedValue< ShapedType > | reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
| void | reshardingRegisterDependentDialects (DialectRegistry &registry) |
| std::unique_ptr<::mlir::Pass > | createPartition () |
| std::unique_ptr<::mlir::Pass > | createShardSimplify () |
| std::unique_ptr<::mlir::Pass > | createShardingPropagation () |
| std::unique_ptr<::mlir::Pass > | createShardingPropagation (ShardingPropagationOptions options) |
| void | registerPartition () |
| void | registerPartitionPass () |
| void | registerShardSimplify () |
| void | registerShardSimplifyPass () |
| void | registerShardingPropagation () |
| void | registerShardingPropagationPass () |
| void | registerShardPasses () |
| template<typename AlgebraicOp> | |
| void | populateAllReduceEndomorphismSimplifyPatterns (RewritePatternSet &patterns, ReductionKind reduction) |
| void | populateSimplifyPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllSliceOpLoweringDialects (DialectRegistry &registry) |
| void | populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
| void | registerAllOpLoweringDialects (DialectRegistry &registry) |
| TypedValue< IndexType > | createCollectiveProcessGroupSize (GridOp grid, ArrayRef< GridAxis > axes, ImplicitLocOpBuilder &builder) |
| TypedValue< IndexType > | createProcessLinearIndex (ImplicitLocOpBuilder &builder, StringRef grid, ArrayRef< GridAxis > gridAxes={}) |
| TypedValue< IndexType > | createProcessLinearIndex (ImplicitLocOpBuilder &builder, StringRef grid, ValueRange processInGroupMultiIndex, ArrayRef< GridAxis > gridAxes={}) |
| template<typename SourceAxes, typename TargetAxes> | |
| static bool | arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
| static TypedValue< ShapedType > | reshard (ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &srcSharding, const Sharding &tgtSharding, TypedValue< ShapedType > unshardedSrc, TypedValue< ShapedType > shardedSrc) |
| static SmallVector< Type > | shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
| static LogicalResult | partitionOperation (Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static std::vector< Sharding > | getOperandShardings (Operation &op) |
| static std::vector< Sharding > | getResultShardings (Operation &op) |
| static LogicalResult | partitionOperation (ShardOp shardOp, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | checkFullyAnnotated (Block &block) |
| static LogicalResult | checkFullyAnnotated (Operation *op) |
| static LogicalResult | partitionOperation (Operation &op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionBlock (Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
| static LogicalResult | partitionFuncOp (FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection) |
| using mlir::shard::GridAxesAttr = DenseI16ArrayAttr |
Definition at line 28 of file ShardOps.h.
| using mlir::shard::GridAxis = int16_t |
Definition at line 27 of file ShardOps.h.
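| using mlir::shard::HaloSizePairAttr = DenseI64ArrayAttr |
Definition at line 30 of file ShardOps.h.
| using mlir::shard::ShardingArray = SmallVector<SmallVector<GridAxis>> |
Definition at line 25 of file ShardingInterface.h.
| using mlir::shard::ShardingArrayRef = ArrayRef<SmallVector<GridAxis>> |
Definition at line 26 of file ShardingInterface.h.
| using mlir::shard::ShardShapeAttr = DenseI64ArrayAttr |
Definition at line 29 of file ShardOps.h.
| using mlir::shard::UnshardedToShardedValueMap = DenseMap<Value, Value> |
Definition at line 522 of file Partition.cpp.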
|
strong |
|
static |
Definition at line 42 of file Partition.cpp.
|
static |
Definition at line 667 of file Partition.cpp.
References checkFullyAnnotated(), mlir::Block::computeBlockNumber(), mlir::emitError(), mlir::Block::getArguments(), mlir::Region::getLoc(), mlir::Block::getParent(), mlir::Operation::getUsers(), and success().
Referenced by checkFullyAnnotated(), checkFullyAnnotated(), partitionBlock(), and partitionOperation().
|
static |
Definition at line 699 of file Partition.cpp.
References checkFullyAnnotated(), mlir::Operation::emitError(), mlir::Operation::getOpOperands(), mlir::Operation::getResults(), mlir::Operation::hasTrait(), result, and success().
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridOp | grid ) |
Definition at line 172 of file ShardOps.h.
References collectiveProcessGroupSize().
| int64_t mlir::shard::collectiveProcessGroupSize | ( | GridAxesRange && | gridAxes, |
| GridShapeRange && | gridShape ) |
Definition at line 156 of file ShardOps.h.
Referenced by collectiveProcessGroupSize(), sliceResultType(), verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
| TypedValue< IndexType > mlir::shard::createCollectiveProcessGroupSize | ( | GridOp | grid, |
| ArrayRef< GridAxis > | axes, | ||
| ImplicitLocOpBuilder & | builder ) |
Definition at line 201 of file Transforms.cpp.
References mlir::arith::createProduct(), mlir::Builder::getIndexType(), and mlir::ImplicitLocOpBuilder::getLoc().
| std::unique_ptr<::mlir::Pass > mlir::shard::createPartition | ( | ) |
We declare an explicit private instantiation because Pass classes should only be visible to the current library.
Definition at line 81 of file Partition.cpp.
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | ImplicitLocOpBuilder & | builder, |
| StringRef | grid, | ||
| ArrayRef< GridAxis > | gridAxes = {} ) |
Definition at line 227 of file Transforms.cpp.
References createProcessLinearIndex().
Referenced by mlir::linalg::createDestinationPassingStyleInitOperand(), and createProcessLinearIndex().
| TypedValue< IndexType > mlir::shard::createProcessLinearIndex | ( | ImplicitLocOpBuilder & | builder, |
| StringRef | grid, | ||
| ValueRange | processInGroupMultiIndex, | ||
| ArrayRef< GridAxis > | gridAxes = {} ) |
Definition at line 211 of file Transforms.cpp.
References mlir::arith::ConstantIndexOp::create().
| std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation | ( | ) |
Definition at line 257 of file ShardingPropagation.cpp.
| std::unique_ptr<::mlir::Pass > mlir::shard::createShardingPropagation | ( | ShardingPropagationOptions | options | ) |
Definition at line 261 of file ShardingPropagation.cpp.
References b.
| std::unique_ptr<::mlir::Pass > mlir::shard::createShardSimplify | ( | ) |
We declare an explicit private instantiation because Pass classes should only be visible to the current library.
Definition at line 157 of file Simplify.cpp.
Definition at line 187 of file ShardOps.h.
| shard::GridOp mlir::shard::getGrid | ( | Op | op, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 140 of file ShardOps.h.
References getGrid(), and mlir::Op< ConcreteType, Traits >::getOperation().
|
inline |
Definition at line 131 of file ShardOps.h.
References getGridOrNull().
Referenced by mlir::linalg::getGrid(), getGrid(), getGrid< ShardOp >(), reshard(), and shardedBlockArgumentTypes().
|
inline |
Definition at line 145 of file ShardOps.h.
References getGrid().
| ShardingArray mlir::shard::getGridAxisAssignmentForLoopIterators | ( | ArrayRef< Sharding > | operandShardings, |
| ArrayRef< Sharding > | resultShardings, | ||
| ArrayRef< utils::IteratorType > | loopIteratorTypes, | ||
| ArrayRef< AffineMap > | indexingMaps ) |
Definition at line 572 of file ShardingInterface.cpp.
References updateGridAxisAssignmentForLoopIterators().
|
inline |
Definition at line 123 of file ShardOps.h.
References mlir::SymbolTableCollection::lookupNearestSymbolFrom().
Referenced by getGrid(), getGridAndVerify(), and partitionTriviallyShardableOperation().
| SmallVector< Value > mlir::shard::getMixedAsValues | ( | OpBuilder | b, |
| const Location & | loc, | ||
| llvm::ArrayRef< int64_t > | statics, | ||
| ValueRange | dynamics, | ||
| Type | type = Type() ) |
Converts a vector of OpFoldResults (ints) into a vector of Values of the provided type.
Definition at line 77 of file ShardOps.cpp.
References b.
Definition at line 581 of file Partition.cpp.
References mlir::Value::getDefiningOp(), mlir::Operation::getNumOperands(), mlir::Operation::getOperands(), and getOperandShardings().
Referenced by getOperandShardings(), and partitionOperation().
| SmallVector< GridAxis > mlir::shard::getReductionGridAxes | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis > > | gridAxisAssignmentForLoopIterators ) |
Definition at line 625 of file ShardingInterface.cpp.
Referenced by mlir::linalg::partitionLinalgOpWithShardedReduction().
Definition at line 601 of file Partition.cpp.
References mlir::Operation::getNumResults(), mlir::Operation::getOperands(), mlir::Operation::getResults(), getResultShardings(), if(), and result.
Referenced by getResultShardings(), and partitionOperation().
Definition at line 152 of file ShardingInterface.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), and mlir::Value::getDefiningOp().
Definition at line 109 of file ShardingInterface.cpp.
References getSharding(), mlir::Value::getUsers(), mlir::Value::hasOneUse(), and result.
Referenced by addShardOp(), addShardOp(), mlir::shard::detail::defaultGetShardingAnnotations(), getSharding(), and visitOp().
| bool mlir::shard::isAtLeastOneReductionIteratorSharded | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
| ArrayRef< SmallVector< GridAxis > > | gridAxisAssignmentForLoopIterators ) |
Definition at line 612 of file ShardingInterface.cpp.
Definition at line 116 of file ShardOps.h.
References mlir::shard::Sharding::getSplitAxes().
Referenced by isValueCompatibleWithFullReplicationSharding(), maybeInsertSourceShardingAnnotation(), and reshard().
|
inline |
Definition at line 104 of file ShardOps.h.
| void mlir::shard::maybeInsertSourceShardingAnnotation | ( | Sharding | sharding, |
| OpOperand & | operand, | ||
| OpBuilder & | builder ) |
Definition at line 352 of file ShardOps.cpp.
References mlir::IROperand< DerivedT, IRValueT >::get(), mlir::Value::getDefiningOp(), mlir::Value::getLoc(), mlir::detail::IROperandBase::getOwner(), mlir::Value::getType(), mlir::Operation::hasTrait(), isFullReplication(), mlir::RewriterBase::replaceUsesWithIf(), and mlir::OpBuilder::setInsertionPoint().
Referenced by addShardOp().
| void mlir::shard::maybeInsertTargetShardingAnnotation | ( | Sharding | sharding, |
| OpResult | result, | ||
| OpBuilder & | builder ) |
Definition at line 338 of file ShardOps.cpp.
References maybeInsertTargetShardingAnnotationImpl(), and result.
Referenced by addShardOp().
|
inline |
Definition at line 85 of file ShardOps.h.
References diag().
| llvm::raw_ostream & mlir::shard::operator<< | ( | llvm::raw_ostream & | os, |
| const Sharding & | sharding ) |
Definition at line 751 of file ShardOps.cpp.
References mlir::shard::Sharding::getGrid(), mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), and mlir::shard::Sharding::getStaticShardedDimsOffsets().
|
static |
Definition at line 778 of file Partition.cpp.
References checkFullyAnnotated(), mlir::OpBuilder::createBlock(), mlir::Block::getArguments(), mlir::Block::getOperations(), mlir::Block::getParent(), mlir::IRMapping::map(), partitionBlock(), partitionOperation(), mlir::OpBuilder::setInsertionPointToEnd(), shardedBlockArgumentTypes(), and success().
Referenced by partitionBlock(), and partitionFuncOp().
| void mlir::shard::partitionFullyReplicatedOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder ) |
Definition at line 543 of file ShardingInterface.cpp.
References areValuesCompatibleWithFullReplicationShardings(), mlir::OpBuilder::clone(), mlir::Operation::getOperands(), and mlir::Operation::getResults().
Referenced by partitionOperation().
|
static |
Definition at line 809 of file Partition.cpp.
References b, mlir::Operation::getOperandTypes(), partitionBlock(), partitionFuncOp(), and success().
Referenced by partitionFuncOp().
|
static |
Definition at line 552 of file Partition.cpp.
References mlir::Operation::getResults(), partitionFullyReplicatedOperation(), partitionOperation(), result, and success().
Referenced by partitionBlock(), partitionOperation(), partitionOperation(), and partitionOperation().
|
static |
Definition at line 739 of file Partition.cpp.
References checkFullyAnnotated(), mlir::OpBuilder::clone(), mlir::Operation::emitError(), mlir::Operation::getOperands(), getOperandShardings(), mlir::Operation::getResult(), getResultShardings(), mlir::IRMapping::map(), partitionOperation(), and success().
|
static |
Definition at line 635 of file Partition.cpp.
References mlir::IRMapping::contains(), mlir::Value::getDefiningOp(), mlir::IRMapping::lookup(), mlir::IRMapping::map(), partitionOperation(), reshard(), and success().
| void mlir::shard::partitionTriviallyShardableOperation | ( | Operation & | op, |
| ArrayRef< Value > | partitionedOperands, | ||
| ArrayRef< Sharding > | operandShardings, | ||
| ArrayRef< Sharding > | resultShardings, | ||
| IRMapping & | partitionMap, | ||
| SymbolTableCollection & | symbolTable, | ||
| OpBuilder & | builder ) |
Definition at line 638 of file ShardingInterface.cpp.
References mlir::OpBuilder::clone(), getGridOrNull(), mlir::Operation::getResults(), and shardType().
Referenced by mlir::shard::ElementwiseShardingInterface< ElemwiseOp >::partition(), and mlir::shard::IndependentParallelIteratorDomainShardingInterface< Op >::partition().
| void mlir::shard::populateAllOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 189 of file Transforms.cpp.
References populateAllSliceOpLoweringPatterns(), and populateProcessMultiIndexOpLoweringPatterns().
| void mlir::shard::populateAllReduceEndomorphismSimplifyPatterns | ( | RewritePatternSet & | patterns, |
| ReductionKind | reduction ) |
Definition at line 40 of file Simplify.h.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
| void mlir::shard::populateAllSliceOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 177 of file Transforms.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateFoldingPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 164 of file Simplify.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
| void mlir::shard::populateProcessMultiIndexOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 167 of file Transforms.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateAllOpLoweringPatterns().
| void mlir::shard::populateSimplifyPatterns | ( | RewritePatternSet & | patterns, |
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 136 of file Simplify.cpp.
| void mlir::shard::registerAllOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 195 of file Transforms.cpp.
References registerAllSliceOpLoweringDialects(), and registerProcessMultiIndexOpLoweringDialects().
| void mlir::shard::registerAllSliceOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 183 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
| void mlir::shard::registerProcessMultiIndexOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 173 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
|
inline |
|
inline |
Definition at line 341 of file Passes.h.
Referenced by mlir::registerAllPasses().
| void mlir::shard::removeTrailingEmptySubArray | ( | SmallVector< SmallVector< T > > & | array | ) |
Definition at line 110 of file ShardOps.h.
Referenced by mlir::shard::detail::defaultGetShardingOption(), getSharding(), and getSharding().
|
static |
Definition at line 443 of file Partition.cpp.
References mlir::ImplicitLocOpBuilder::emitError(), isFullReplication(), and shardShapedType().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| GridOp | grid, | ||
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue ) |
Definition at line 495 of file Partition.cpp.
References reshard().
Referenced by partitionOperation(), reshard(), and reshard().
| TypedValue< ShapedType > mlir::shard::reshard | ( | OpBuilder & | builder, |
| ShardOp | source, | ||
| ShardOp | target, | ||
| TypedValue< ShapedType > | sourceShardValue, | ||
| SymbolTableCollection & | symbolTableCollection ) |
Definition at line 506 of file Partition.cpp.
| void mlir::shard::reshardingRegisterDependentDialects | ( | DialectRegistry & | registry | ) |
Definition at line 515 of file Partition.cpp.
References mlir::DialectRegistry::insert().
Definition at line 178 of file ShardOps.h.
|
static |
Definition at line 528 of file Partition.cpp.
References mlir::Block::getArguments(), getGrid(), mlir::Operation::getUsers(), and shardShapedType().
Referenced by partitionBlock().
| ShapedType mlir::shard::shardShapedType | ( | ShapedType | shape, |
| GridOp | grid, | ||
| Sharding | sharding ) |
Definition at line 281 of file ShardOps.cpp.
References mlir::shard::Sharding::getSplitAxes(), mlir::shard::Sharding::getStaticHaloSizes(), mlir::shard::Sharding::getStaticShardedDimsOffsets(), and shardShape().
Referenced by reshard(), shardedBlockArgumentTypes(), and shardType().
Definition at line 291 of file ShardOps.cpp.
References shardShapedType().
Referenced by partitionTriviallyShardableOperation().