MLIR
19.0.0git
|
Namespaces | |
detail | |
Classes | |
struct | ShardingOption |
struct | IndependentParallelIteratorDomainShardingInterface |
struct | ElementwiseShardingInterface |
struct | OpRewritePatternWithSymbolTableCollection |
Typedefs | |
using | ShardingArray = SmallVector< SmallVector< MeshAxis > > |
using | ShardingArrayRef = ArrayRef< SmallVector< MeshAxis > > |
using | MeshAxis = int16_t |
using | MeshAxesAttr = DenseI16ArrayAttr |
using | UnshardedToShardedValueMap = DenseMap< Value, Value > |
Functions | |
FailureOr< std::pair< bool, MeshShardingAttr > > | getMeshShardingAttr (OpResult result) |
FailureOr< std::pair< bool, MeshShardingAttr > > | getMeshShardingAttr (OpOperand &opOperand) |
void | spmdizeFullyReplicatedOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
ShardingArray | getMeshAxisAssignmentForLoopIterators (ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps) |
bool | isAtLeastOneReductionIteratorSharded (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators) |
SmallVector< MeshAxis > | getReductionMeshAxes (ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators) |
void | spmdizeTriviallyShardableOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) |
bool | isReductionLoop (utils::IteratorType iType) |
template<typename T > | |
void | removeTrailingEmptySubArray (SmallVector< SmallVector< T >> &array) |
bool | isFullReplication (MeshShardingAttr attr) |
mesh::MeshOp | getMesh (Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection) |
template<typename Op > | |
mesh::MeshOp | getMesh (Op op, SymbolTableCollection &symbolTableCollection) |
template<> | |
mesh::MeshOp | getMesh< ShardOp > (ShardOp op, SymbolTableCollection &symbolTableCollection) |
template<typename MeshAxesRange , typename MeshShapeRange > | |
int64_t | collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape) |
template<typename MeshAxesRange > | |
int64_t | collectiveProcessGroupSize (MeshAxesRange &&meshAxes, MeshOp mesh) |
int64_t | shardDimension (int64_t dimSize, int64_t shardCount) |
int64_t | gatherDimension (int64_t dimSize, int64_t shardCount) |
ShapedType | shardShapedType (ShapedType shape, MeshOp mesh, MeshShardingAttr sharding) |
Type | shardType (Type type, MeshOp mesh, MeshShardingAttr sharding) |
template<typename AlgebraicOp > | |
void | populateAllReduceEndomorphismSimplificationPatterns (RewritePatternSet &patterns, ReductionKind reduction) |
void | populateSimplificationPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | populateFoldingPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
TypedValue< ShapedType > | reshard (OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue) |
TypedValue< ShapedType > | reshard (OpBuilder &builder, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue, SymbolTableCollection &symbolTableCollection) |
void | reshardingRegisterDependentDialects (DialectRegistry &registry) |
void | populateProcessMultiIndexOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerProcessMultiIndexOpLoweringDialects (DialectRegistry &registry) |
void | populateAllSliceOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerAllSliceOpLoweringDialects (DialectRegistry &registry) |
void | populateAllOpLoweringPatterns (RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) |
void | registerAllOpLoweringDialects (DialectRegistry &registry) |
TypedValue< IndexType > | createCollectiveProcessGroupSize (MeshOp mesh, ArrayRef< MeshAxis > axes, ImplicitLocOpBuilder &builder) |
TypedValue< IndexType > | createProcessLinearIndex (StringRef mesh, ArrayRef< MeshAxis > meshAxes, ImplicitLocOpBuilder &builder) |
template<typename SourceAxes , typename TargetAxes > | |
static bool | arePartialAxesCompatible (const SourceAxes &sourceAxes, const TargetAxes &targetAxes) |
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | handlePartialAxesDuringResharding (OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard) |
static MeshShardingAttr | targetShardingInSplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | splitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
static std::optional< std::tuple< int64_t, MeshAxis > > | detectSplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding) |
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > | trySplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< int64_t, MeshAxis > > | detectUnsplitLastAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding) |
static MeshShardingAttr | targetShardingInUnsplitLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis) |
static ShapedType | allGatherResultShapeInUnsplitLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) |
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | unsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis) |
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > | tryUnsplitLastAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > | detectMoveLastSplitAxisInResharding (MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding) |
static MeshShardingAttr | targetShardingInMoveLastAxis (MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static ShapedType | allToAllResultShapeInMoveLastAxis (ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis) |
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > | moveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis) |
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > | tryMoveLastSplitAxisInResharding (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard) |
static TypedValue< ShapedType > | reshardOn1DMesh (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
TypedValue< ShapedType > | reshard (ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard) |
SmallVector< Type > | shardedBlockArgumentTypes (Block &block, SymbolTableCollection &symbolTableCollection) |
static LogicalResult | spmdizeOperation (Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static SmallVector< MeshShardingAttr > | getOperandShardings (Operation &op) |
static SmallVector< MeshShardingAttr > | getResultShardings (Operation &op) |
static LogicalResult | spmdizeOperation (ShardOp shardOp, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | spmdizeOperation (Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | spmdizeBlock (Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) |
static LogicalResult | spmdizeFuncOp (FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection) |
using mlir::mesh::MeshAxesAttr = typedef DenseI16ArrayAttr |
using mlir::mesh::MeshAxis = typedef int16_t |
using mlir::mesh::ShardingArray = typedef SmallVector<SmallVector<MeshAxis> > |
Definition at line 25 of file ShardingInterface.h.
using mlir::mesh::ShardingArrayRef = typedef ArrayRef<SmallVector<MeshAxis> > |
Definition at line 26 of file ShardingInterface.h.
using mlir::mesh::UnshardedToShardedValueMap = typedef DenseMap<Value, Value> |
Definition at line 521 of file Spmdization.cpp.
|
static |
Definition at line 251 of file Spmdization.cpp.
References gatherDimension().
Referenced by unsplitLastAxisInResharding().
|
static |
Definition at line 378 of file Spmdization.cpp.
References gatherDimension(), and shardDimension().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 45 of file Spmdization.cpp.
Referenced by handlePartialAxesDuringResharding().
int64_t mlir::mesh::collectiveProcessGroupSize | ( | MeshAxesRange && | meshAxes, |
MeshOp | mesh | ||
) |
Definition at line 95 of file MeshOps.h.
References collectiveProcessGroupSize().
int64_t mlir::mesh::collectiveProcessGroupSize | ( | MeshAxesRange && | meshAxes, |
MeshShapeRange && | meshShape | ||
) |
Definition at line 79 of file MeshOps.h.
Referenced by collectiveProcessGroupSize(), shardShape(), sliceResultType(), verifyAllToAllOperandAndResultShape(), verifyGatherOperandAndResultShape(), and verifyScatterOrSliceOperandAndResultShape().
TypedValue< IndexType > mlir::mesh::createCollectiveProcessGroupSize | ( | MeshOp | mesh, |
ArrayRef< MeshAxis > | axes, | ||
ImplicitLocOpBuilder & | builder | ||
) |
Definition at line 201 of file Transforms.cpp.
References mlir::ImplicitLocOpBuilder::create(), mlir::arith::createProduct(), mlir::Builder::getIndexType(), and mlir::ImplicitLocOpBuilder::getLoc().
TypedValue< IndexType > mlir::mesh::createProcessLinearIndex | ( | StringRef | mesh, |
ArrayRef< MeshAxis > | meshAxes, | ||
ImplicitLocOpBuilder & | builder | ||
) |
Definition at line 210 of file Transforms.cpp.
References mlir::ImplicitLocOpBuilder::create(), and mlir::affine::linearizeIndex().
|
static |
Definition at line 307 of file Spmdization.cpp.
Referenced by tryMoveLastSplitAxisInResharding().
|
static |
Definition at line 154 of file Spmdization.cpp.
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 204 of file Spmdization.cpp.
Referenced by tryUnsplitLastAxisInResharding().
|
inline |
Definition at line 110 of file MeshOps.h.
Referenced by allGatherResultShapeInUnsplitLastAxis(), and allToAllResultShapeInMoveLastAxis().
mesh::MeshOp mlir::mesh::getMesh | ( | Op | op, |
SymbolTableCollection & | symbolTableCollection | ||
) |
|
inline |
Definition at line 57 of file MeshOps.h.
Referenced by reshard(), and shardedBlockArgumentTypes().
|
inline |
ShardingArray mlir::mesh::getMeshAxisAssignmentForLoopIterators | ( | ArrayRef< MeshShardingAttr > | operandShardings, |
ArrayRef< MeshShardingAttr > | resultShardings, | ||
ArrayRef< utils::IteratorType > | loopIteratorTypes, | ||
ArrayRef< AffineMap > | indexingMaps | ||
) |
Definition at line 582 of file ShardingInterface.cpp.
References updateMeshAxisAssignmentForLoopIterators().
FailureOr< std::pair< bool, MeshShardingAttr > > mlir::mesh::getMeshShardingAttr | ( | OpOperand & | opOperand | ) |
Definition at line 143 of file ShardingInterface.cpp.
References mlir::failure(), mlir::IROperand< DerivedT, IRValueT >::get(), and mlir::Value::getDefiningOp().
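FailureOr< std::pair< bool, MeshShardingAttr > > mlir::mesh::getMeshShardingAttr | ( | OpResult | result | ) |
Definition at line 99 of file ShardingInterface.cpp.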
References mlir::failure(), mlir::Value::getUsers(), and mlir::Value::hasOneUse().
Referenced by addShardOp().
|
static |
Definition at line 578 of file Spmdization.cpp.
SmallVector< MeshAxis > mlir::mesh::getReductionMeshAxes | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
ArrayRef< SmallVector< MeshAxis >> | meshAxisAssignmentForLoopIterators | ||
) |
Definition at line 636 of file ShardingInterface.cpp.
|
static |
Definition at line 598 of file Spmdization.cpp.
|
static |
Definition at line 59 of file Spmdization.cpp.
References arePartialAxesCompatible(), mlir::OpBuilder::create(), mlir::get(), mlir::Builder::getContext(), and mlir::OpBuilder::setInsertionPointAfterValue().
Referenced by reshardOn1DMesh().
bool mlir::mesh::isAtLeastOneReductionIteratorSharded | ( | ArrayRef< utils::IteratorType > | loopIteratorTypes, |
ArrayRef< SmallVector< MeshAxis >> | meshAxisAssignmentForLoopIterators | ||
) |
Definition at line 623 of file ShardingInterface.cpp.
|
inline |
Definition at line 53 of file MeshOps.h.
Referenced by isValueCompatibleWithFullReplicationSharding().
|
inline |
Definition at line 42 of file MeshOps.h.
Referenced by addShardOp().
|
static |
Definition at line 391 of file Spmdization.cpp.
References allToAllResultShapeInMoveLastAxis(), mlir::ImplicitLocOpBuilder::create(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInMoveLastAxis().
Referenced by tryMoveLastSplitAxisInResharding().
void mlir::mesh::populateAllOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 189 of file Transforms.cpp.
References populateAllSliceOpLoweringPatterns(), and populateProcessMultiIndexOpLoweringPatterns().
void mlir::mesh::populateAllReduceEndomorphismSimplificationPatterns | ( | RewritePatternSet & | patterns, |
ReductionKind | reduction | ||
) |
Definition at line 40 of file Simplifications.h.
void mlir::mesh::populateAllSliceOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 177 of file Transforms.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateAllOpLoweringPatterns().
void mlir::mesh::populateFoldingPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 117 of file Simplifications.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateSimplificationPatterns().
void mlir::mesh::populateProcessMultiIndexOpLoweringPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 167 of file Transforms.cpp.
References mlir::RewritePatternSet::add(), and mlir::RewritePatternSet::getContext().
Referenced by populateAllOpLoweringPatterns().
void mlir::mesh::populateSimplificationPatterns | ( | RewritePatternSet & | patterns, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 26 of file Simplifications.cpp.
References populateFoldingPatterns().
void mlir::mesh::registerAllOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 195 of file Transforms.cpp.
References registerAllSliceOpLoweringDialects(), and registerProcessMultiIndexOpLoweringDialects().
void mlir::mesh::registerAllSliceOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 183 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
void mlir::mesh::registerProcessMultiIndexOpLoweringDialects | ( | DialectRegistry & | registry | ) |
Definition at line 173 of file Transforms.cpp.
References mlir::DialectRegistry::insert().
Referenced by registerAllOpLoweringDialects().
void mlir::mesh::removeTrailingEmptySubArray | ( | SmallVector< SmallVector< T >> & | array | ) |
Definition at line 47 of file MeshOps.h.
Referenced by addShardOp().
TypedValue<ShapedType> mlir::mesh::reshard | ( | ImplicitLocOpBuilder & | builder, |
MeshOp | mesh, | ||
MeshShardingAttr | sourceSharding, | ||
MeshShardingAttr | targetSharding, | ||
TypedValue< ShapedType > | sourceUnshardedValue, | ||
TypedValue< ShapedType > | sourceShard | ||
) |
Definition at line 481 of file Spmdization.cpp.
References reshardOn1DMesh().
Referenced by spmdizeOperation().
TypedValue< ShapedType > mlir::mesh::reshard | ( | OpBuilder & | builder, |
MeshOp | mesh, | ||
ShardOp | source, | ||
ShardOp | target, | ||
TypedValue< ShapedType > | sourceShardValue | ||
) |
Definition at line 493 of file Spmdization.cpp.
Referenced by reshard().
TypedValue< ShapedType > mlir::mesh::reshard | ( | OpBuilder & | builder, |
ShardOp | source, | ||
ShardOp | target, | ||
TypedValue< ShapedType > | sourceShardValue, | ||
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 505 of file Spmdization.cpp.
void mlir::mesh::reshardingRegisterDependentDialects | ( | DialectRegistry & | registry | ) |
Definition at line 514 of file Spmdization.cpp.
References mlir::DialectRegistry::insert().
|
static |
Definition at line 438 of file Spmdization.cpp.
References mlir::Builder::getType(), handlePartialAxesDuringResharding(), shardShapedType(), tryMoveLastSplitAxisInResharding(), trySplitLastAxisInResharding(), and tryUnsplitLastAxisInResharding().
Referenced by reshard().
|
inline |
Definition at line 101 of file MeshOps.h.
References mlir::ceilDiv().
Referenced by allToAllResultShapeInMoveLastAxis(), and shardShape().
SmallVector<Type> mlir::mesh::shardedBlockArgumentTypes | ( | Block & | block, |
SymbolTableCollection & | symbolTableCollection | ||
) |
Definition at line 527 of file Spmdization.cpp.
References mlir::Block::getArguments(), getMesh(), mlir::Operation::getUsers(), and shardShapedType().
Referenced by spmdizeBlock().
ShapedType mlir::mesh::shardShapedType | ( | ShapedType | shape, |
MeshOp | mesh, | ||
MeshShardingAttr | sharding | ||
) |
Definition at line 162 of file MeshOps.cpp.
References shardShape().
Referenced by moveLastSplitAxisInResharding(), reshardOn1DMesh(), shardedBlockArgumentTypes(), shardType(), and unsplitLastAxisInResharding().
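Type mlir::mesh::shardType | ( | Type | type, |
MeshOp | mesh, | ||
MeshShardingAttr | sharding | ||
) |
Definition at line 171 of file MeshOps.cpp.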
References shardShapedType().
|
static |
Definition at line 133 of file Spmdization.cpp.
References mlir::ImplicitLocOpBuilder::create(), mlir::Builder::getContext(), and targetShardingInSplitLastAxis().
Referenced by trySplitLastAxisInResharding().
|
static |
Definition at line 664 of file Spmdization.cpp.
References mlir::OpBuilder::createBlock(), mlir::Block::getArguments(), mlir::Block::getParent(), mlir::IRMapping::map(), mlir::OpBuilder::setInsertionPointToEnd(), and shardedBlockArgumentTypes().
void mlir::mesh::spmdizeFullyReplicatedOperation | ( | Operation & | op, |
ArrayRef< Value > | spmdizedOperands, | ||
ArrayRef< MeshShardingAttr > | operandShardings, | ||
ArrayRef< MeshShardingAttr > | resultShardings, | ||
IRMapping & | spmdizationMap, | ||
SymbolTableCollection & | symbolTable, | ||
OpBuilder & | builder | ||
) |
Definition at line 553 of file ShardingInterface.cpp.
|
static |
Definition at line 691 of file Spmdization.cpp.
|
static |
Definition at line 549 of file Spmdization.cpp.
|
static |
Definition at line 644 of file Spmdization.cpp.
|
static |
Definition at line 618 of file Spmdization.cpp.
References mlir::IRMapping::contains(), mlir::IRMapping::lookup(), mlir::IRMapping::map(), reshard(), and mlir::success().
void mlir::mesh::spmdizeTriviallyShardableOperation | ( | Operation & | op, |
ArrayRef< Value > | spmdizedOperands, | ||
ArrayRef< MeshShardingAttr > | operandShardings, | ||
ArrayRef< MeshShardingAttr > | resultShardings, | ||
IRMapping & | spmdizationMap, | ||
SymbolTableCollection & | symbolTable, | ||
OpBuilder & | builder | ||
) |
Definition at line 649 of file ShardingInterface.cpp.
|
static |
Definition at line 349 of file Spmdization.cpp.
References mlir::detail::DenseArrayAttrImpl< T >::get(), and mlir::get().
Referenced by moveLastSplitAxisInResharding().
|
static |
Definition at line 111 of file Spmdization.cpp.
References mlir::detail::DenseArrayAttrImpl< T >::get(), and mlir::get().
Referenced by splitLastAxisInResharding().
|
static |
Definition at line 233 of file Spmdization.cpp.
References mlir::detail::DenseArrayAttrImpl< T >::get(), and mlir::get().
Referenced by unsplitLastAxisInResharding().
|
static |
Definition at line 418 of file Spmdization.cpp.
References detectMoveLastSplitAxisInResharding(), and moveLastSplitAxisInResharding().
Referenced by reshardOn1DMesh().
|
static |
Definition at line 186 of file Spmdization.cpp.
References detectSplitLastAxisInResharding(), and splitLastAxisInResharding().
Referenced by reshardOn1DMesh().
|
static |
Definition at line 285 of file Spmdization.cpp.
References detectUnsplitLastAxisInResharding(), and unsplitLastAxisInResharding().
Referenced by reshardOn1DMesh().
|
static |
Definition at line 260 of file Spmdization.cpp.
References allGatherResultShapeInUnsplitLastAxis(), mlir::ImplicitLocOpBuilder::create(), mlir::get(), mlir::Builder::getContext(), mlir::OpBuilder::setInsertionPointAfterValue(), shardShapedType(), and targetShardingInUnsplitLastAxis().
Referenced by tryUnsplitLastAxisInResharding().