#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include <iterator>
#include <optional>
#include <tuple>
#include <type_traits>
// Returns true if the target partial axes are a subset of the source partial
// axes, i.e. every reduction the target still expects to be pending is also
// pending on the source.
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}
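// Illustrative sketch (editorial, not part of the pass; names and types are
// hypothetical): the same subset check expressed with plain std containers,
// to make the intent of arePartialAxesCompatible concrete.
#if 0
#include <algorithm>
#include <set>
#include <vector>

static bool examplePartialAxesCompatible() {
  std::set<int> sourcePartialAxes{0, 1};  // reductions still pending on source
  std::vector<int> targetPartialAxes{1};  // reductions the target keeps pending
  // {1} is a subset of {0, 1}, so the two shardings are compatible.
  return std::all_of(targetPartialAxes.begin(), targetPartialAxes.end(),
                     [&](int axis) { return sourcePartialAxes.count(axis) != 0; });
}
#endif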
// Resolve pending partial reductions while resharding: mesh axes that are
// partial on the source but not on the target are all-reduced here; the rest
// remain partial in the resulting sharding.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
handlePartialAxesDuringResharding(OpBuilder &builder,
                                  MeshSharding sourceSharding,
                                  MeshSharding targetSharding,
                                  TypedValue<ShapedType> sourceShard) {
  if (sourceSharding.getPartialAxes().empty() &&
      targetSharding.getPartialAxes().empty()) {
    return {sourceShard, sourceSharding};
  }
  // ...
  using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
  using AxisSet = llvm::SmallDenseSet<Axis>;
  AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
                                       sourceSharding.getPartialAxes().end());
  AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
                                       targetSharding.getPartialAxes().end());
  assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
                                  targetShardingPartialAxesSet));

  // Mesh axes that must be all-reduced now: partial on the source but not on
  // the target.
  llvm::SmallVector<MeshAxis> allReduceMeshAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return !targetShardingPartialAxesSet.contains(a);
                });
  if (allReduceMeshAxes.empty()) {
    return {sourceShard, sourceSharding};
  }

  builder.setInsertionPointAfterValue(sourceShard);
  TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>(
      builder
          .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
                               /*...*/ allReduceMeshAxes, sourceShard,
                               sourceSharding.getPartialType())
          .getResult());

  // Mesh axes that remain partial on both sides.
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(allReduceMeshAxes),
                [&targetShardingPartialAxesSet](Axis a) {
                  return targetShardingPartialAxesSet.contains(a);
                });
  // ...
  return {resultValue, resultSharding};
}
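// Illustrative example (hypothetical 1-D mesh of size 4, schematic sharding
// notation): a value that is partial-sum over mesh axis 0, e.g. sharding
// <@mesh, [[]], partial = sum [0]>, resharded to a target with no partial
// axes, lowers to a single all-reduce over mesh axis 0; the resulting
// sharding is identical except that the partial axis is dropped.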
// Sharding that results from additionally splitting tensor axis
// `splitTensorAxis` over mesh axis `splitMeshAxis`.
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                                  MeshSharding sourceSharding,
                                                  int64_t splitTensorAxis,
                                                  MeshAxis splitMeshAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  // Pad with empty per-axis entries until splitTensorAxis is a valid index.
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitMeshAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(sourceSharding.getMeshAttr(),
                           targetShardingSplitAxes, /*...*/);
}
// Split the shard along `splitTensorAxis` once more, over mesh axis
// `splitMeshAxis`, by emitting an all-slice.
static std::tuple<TypedValue<ShapedType>, MeshSharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          MeshSharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, MeshOp mesh,
                          int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder
          .create<AllSliceOp>(sourceShard, mesh, /*...*/
                              splitTensorAxis)
          .getResult());
  MeshSharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
  return {targetShard, targetSharding};
}
// Detect a resharding that appends one mesh axis to the split axes of a
// single tensor axis, e.g. [[0, 1]] -> [[0, 1, 2]]. Returns the tensor axis
// and the newly appended mesh axis.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectSplitLastAxisInResharding(MeshSharding sourceSharding,
                                MeshSharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      // The existing source split axes must be an unchanged prefix of the
      // target's split axes.
      if (!llvm::equal(
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef())) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                             MeshSharding sourceSharding,
                             MeshSharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard,
                                     mesh, tensorAxis, meshAxis);
  }
  return std::nullopt;
}
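// Illustrative example (hypothetical mesh and shapes): on a 2x2 mesh,
// resharding a tensor<8x16xf32> from split axes [[0]] to [[0, 1]] appends
// mesh axis 1 to tensor axis 0. Each process already holds a 4x16 shard
// (axis 0 split over mesh axis 0 of size 2); the emitted all-slice over mesh
// axis 1 (size 2) slices tensor axis 0 again, leaving a 2x16 shard per
// process.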
// Detect a resharding that removes the last split mesh axis from a single
// tensor axis, e.g. [[0, 1, 2]] -> [[0, 1]]. Returns the tensor axis and the
// removed mesh axis.
static std::optional<std::tuple<int64_t, MeshAxis>>
detectUnsplitLastAxisInResharding(MeshSharding sourceSharding,
                                  MeshSharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      // ...
      // All split axes except the trailing one must be preserved.
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
// Sharding that results from removing the last split mesh axis of tensor axis
// `splitTensorAxis`.
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                    MeshSharding sourceSharding,
                                                    int64_t splitTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);
  return MeshSharding::get(sourceSharding.getMeshAttr(),
                           targetShardingSplitAxes, /*...*/);
}
static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
static std::tuple<TypedValue<ShapedType>, MeshSharding>
unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                            MeshSharding sourceSharding,
                            ShapedType sourceUnshardedShape,
                            TypedValue<ShapedType> sourceShard, MeshOp mesh,
                            int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
  MeshSharding targetSharding = targetShardingInUnsplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
  Value allGatherResult = builder.create<AllGatherOp>(
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      /*...*/ sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
  return {targetShard, targetSharding};
}
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, meshAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, mesh,
                                       tensorAxis, meshAxis);
  }
  return std::nullopt;
}
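// Illustrative example (hypothetical mesh and shapes): resharding from
// [[0, 1]] to [[0]] on a 2x4 mesh removes the trailing mesh axis 1 (size 4)
// from the split tensor axis. A per-process shard whose split dimension has
// size 2 is all-gathered over mesh axis 1 into size 8, and a tensor.cast
// adjusts the result to the exact target shard type.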
// Detect a resharding that moves the trailing split mesh axis of one tensor
// axis to become the trailing split mesh axis of another tensor axis, e.g.
// [[0, 1], []] -> [[0], [1]]. Returns (source tensor axis, target tensor
// axis, moved mesh axis).
static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding,
                                    MeshSharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      // The remaining (non-trailing) split axes of both tensor axes must be
      // unchanged by the resharding.
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               /*...*/),
              /*...*/) ||
          !llvm::equal(
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               /*...*/),
              /*...*/))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}
// Sharding that results from moving the trailing split mesh axis of tensor
// axis `sourceTensorAxis` to tensor axis `targetTensorAxis`.
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                                 MeshSharding sourceSharding,
                                                 int64_t sourceTensorAxis,
                                                 int64_t targetTensorAxis) {
  SmallVector<MeshAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {}));
  }

  // Remove the trailing mesh axis from the source tensor axis ...
  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto meshAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      MeshAxesAttr::get(ctx, sourceSplitAxes);

  // ... and append it to the target tensor axis.
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(meshAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      MeshAxesAttr::get(ctx, targetSplitAxes);

  return MeshSharding::get(sourceSharding.getMeshAttr(),
                           targetShardingSplitAxes, /*...*/);
}
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
static std::tuple<TypedValue<ShapedType>, MeshSharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                              MeshSharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, MeshAxis meshAxis) {
  MLIRContext *ctx = builder.getContext();
  MeshSharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = builder.create<AllToAllOp>(
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      /*...*/ sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, mesh, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                                 MeshSharding sourceSharding,
                                 MeshSharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes = detectMoveLastSplitAxisInResharding(sourceSharding,
                                                           targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, meshAxis);
  }
  return std::nullopt;
}
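// Illustrative example (hypothetical 1-D mesh of size 4): resharding a
// tensor<8x16xf32> from [[0], []] (axis 0 split, per-shard 2x16) to
// [[], [0]] (axis 1 split, per-shard 8x4) moves the trailing mesh axis from
// tensor axis 0 to tensor axis 1. This lowers to a single all-to-all that
// concatenates along axis 0 (2 -> 8) and splits along axis 1 (16 -> 4),
// followed by a tensor.cast to the exact target shard type.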
// Handle reshardings that only change halo sizes, i.e. split and partial axes
// are unchanged. The core (owned) data of each shard is copied into a new
// local tensor sized for the target halos, and a halo exchange fills the halo
// regions from the neighboring shards.
static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh,
                          MeshSharding sourceSharding,
                          MeshSharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // ... bail out unless only the halo sizes differ ...
  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
          !ShapedType::isDynamicShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for mesh-spmdization");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Compute the per-dimension core shape (excluding the source halo), the
  // core offsets within the source and target shards, and the target (output)
  // shard shape.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Create the result tensor, copy the core data into it, and exchange halos.
  auto noVals = ValueRange{};
  auto initVal = builder.create<tensor::EmptyOp>(
      sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
  auto core = builder.create<tensor::ExtractSliceOp>(
      sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = builder.create<tensor::InsertSliceOp>(
      sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
      coreShape, strides);
  auto updateHaloResult =
      builder
          .create<UpdateHaloOp>(
              sourceShard.getLoc(),
              RankedTensorType::get(outShape,
                                    sourceShard.getType().getElementType()),
              sourceShard, initOprnd, mesh.getSymName(),
              /*...*/)
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}
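// Illustrative example (hypothetical sizes): a shard whose split dimension
// has core size 6 with source halos [1, 1] is stored locally as 8 elements.
// Resharding to target halos [2, 2] creates a new 10-element local tensor,
// copies the 6 core elements to offset 2 (after the new left halo), and the
// halo exchange fills the 2 + 2 halo elements from the neighboring shards.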
static TypedValue<ShapedType> reshardOn1DMesh(
    ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding,
    MeshSharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue,
    TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");

  auto [reducedSourceShard, reducedSourceSharding] =
      handlePartialAxesDuringResharding(builder, sourceSharding,
                                        targetSharding, sourceShard);

  if (reducedSourceSharding == targetSharding) {
    return reducedSourceShard;
  }

  TypedValue<ShapedType> targetShard;
  MeshSharding actualTargetSharding;
  if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      reducedSourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, mesh, reducedSourceSharding, targetSharding,
            sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, mesh, reducedSourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), reducedSourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}
TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh,
                               MeshSharding sourceSharding,
                               MeshSharding targetSharding,
                               TypedValue<ShapedType> sourceUnshardedValue,
                               TypedValue<ShapedType> sourceShard) {
  // Identical source and target shardings need no resharding.
  if (sourceSharding == targetSharding) {
    return sourceShard;
  }

  // Try to handle the resharding as a halo-size update first, then fall back
  // to the 1D-mesh patterns.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, mesh, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value());
  }
  return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}
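// Illustrative example (hypothetical, schematic sharding notation): resharding
// from <@mesh, [[0]], halo = [1, 1]> to <@mesh, [[0]], halo = [2, 2]> only
// grows the halos and is handled by tryUpdateHaloInResharding above; a change
// of split axes such as [[0]] -> [[]] falls through to reshardOn1DMesh, which
// composes the all-reduce / all-slice / all-gather / all-to-all patterns.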
TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
                 source.getSrc(), sourceShardValue);
}

TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  // Resolve the mesh referenced by the shard ops and delegate.
  auto srcMesh = getMesh(source, symbolTableCollection);
  assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
  return reshard(builder, srcMesh, source, target, sourceShardValue);
}
void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
}
#define GEN_PASS_DEF_SPMDIZATION
#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  // Non-tensor arguments keep their type. A ranked-tensor argument must have
  // exactly one use, a mesh.shard op, which determines its sharded type.
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg) {
          return arg.getType();
        }
        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(),
                                          getMesh(shardOp, symbolTableCollection),
                                          shardOp.getSharding()));
      });
  return res;
}
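// Illustrative example (hypothetical 1-D mesh of size 4): a block argument of
// type tensor<8x16xf32> whose only user is a mesh.shard with split axes
// [[0]] becomes a block argument of type tensor<2x16xf32> in the spmdized
// function.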
// Spmdize an op whose operands and results are all fully replicated: it is
// simply cloned with the already spmdized operands.
void spmdizeFullyReplicatedOperation(Operation &op,
                                     ArrayRef<Value> spmdizedOperands,
                                     ArrayRef<MeshSharding> operandShardings,
                                     ArrayRef<MeshSharding> resultShardings,
                                     IRMapping &spmdizationMap,
                                     SymbolTableCollection &symbolTable,
                                     OpBuilder &builder) {
  // ...
  builder.clone(op, spmdizationMap);
}
// Spmdize one operation given the shardings of its operands and results: ops
// without a ShardingInterface fall back to full replication, others delegate
// to their interface's spmdize().
// ...
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // Without a sharding interface the op is conservatively treated as fully
    // replicated.
    spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings,
                                    resultShardings, spmdizationMap,
                                    symbolTableCollection, builder);
  } else {
    if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
                                         resultShardings, spmdizationMap,
                                         symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) {
    return spmdizationMap.contains(result);
  }));

  return success();
}
static std::vector<MeshSharding> getOperandShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor) {
      return MeshSharding();
    }
    // A sharded operand must be produced by a mesh.shard op.
    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return MeshSharding(shardOp.getSharding());
  });
  return res;
}
static std::vector<MeshSharding> getResultShardings(Operation &op) {
  std::vector<MeshSharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(op.getResults(), std::back_inserter(res),
                  [](OpResult result) {
                    TypedValue<RankedTensorType> rankedTensor =
                        dyn_cast<TypedValue<RankedTensorType>>(result);
                    if (!rankedTensor) {
                      return MeshSharding();
                    }
                    // The result's sharding comes from its mesh.shard user.
                    // ...
                    ShardOp shardOp = llvm::cast<ShardOp>(userOp);
                    return MeshSharding(shardOp.getSharding());
                  });
  return res;
}
static LogicalResult
spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  Value targetSpmdValue;

  // Check whether two shard ops are chained. If not, no resharding is needed,
  // as source and target share the same sharding.
  ShardOp srcShardOp =
      dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
  if (!srcShardOp) {
    targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
  } else {
    // Insert a resharding between the chained shard ops.
    TypedValue<ShapedType> srcSpmdValue =
        cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
    targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
                              symbolTableCollection);
  }

  assert(!spmdizationMap.contains(shardOp.getResult()));
  spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
  return success();
}
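// Illustrative example (schematic, not exact assembly syntax): for a chain
//   %1 = mesh.shard %0 to <@mesh, [[0]]> : tensor<8xf32>
//   %2 = mesh.shard %1 to <@mesh, [[]]> annotate_for_users : tensor<8xf32>
// the two annotations disagree, so reshard() materializes the collectives
// (here an all-gather on a 1-D mesh) that turn the [[0]]-sharded value into a
// replicated one; if both annotations agreed, no collective would be emitted.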
static LogicalResult
spmdizeOperation(Operation &op, IRMapping &spmdizationMap,
                 SymbolTableCollection &symbolTableCollection,
                 OpBuilder &builder) {
  // mesh.sharding ops only describe shardings; there is nothing to spmdize.
  if (isa<ShardingOp>(op)) {
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection,
                            builder);
  }

  SmallVector<Value> spmdizedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
                  [&spmdizationMap](Value operand) {
                    assert(spmdizationMap.contains(operand));
                    return spmdizationMap.lookup(operand);
                  });
  return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op),
                          getResultShardings(op), spmdizationMap,
                          symbolTableCollection, builder);
}
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
                                  SymbolTableCollection &symbolTableCollection,
                                  OpBuilder &builder) {
  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, spmdizedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
  }

  // Spmdize every operation of the original block into the new block.
  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  // ...
  return success();
}
static LogicalResult
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
              SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks; spmdized blocks are created alongside them.
  SmallVector<Block *> originalBlocks;
  llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
                  [](Block &b) { return &b; });

  for (Block *block : originalBlocks) {
    if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
                            builder))) {
      return failure();
    }
  }

  // Erase the original, unsharded blocks.
  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find the return-like terminator and update the function signature to the
  // spmdized entry-block argument types and returned value types.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    // ...
    returnOp = &block.back();
    // ...
  }
  op.setType(FunctionType::get(op->getContext(),
                               op.getFunctionBody().front().getArgumentTypes(),
                               /*...*/));
  return success();
}
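// Illustrative example (hypothetical 1-D mesh of size 4): a function
//   func.func @f(%arg0: tensor<8xf32>) -> tensor<8xf32>
// whose argument and result are annotated with split axes [[0]] is rewritten
// by spmdizeFuncOp to operate on per-process shards,
//   func.func @f(%arg0: tensor<2xf32>) -> tensor<2xf32>
// with the body spmdized block by block.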
struct Spmdization : public impl::SpmdizationBase<Spmdization> {
  void runOnOperation() override {
    IRMapping spmdizationMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(spmdizeFuncOp(getOperation(), spmdizationMap,
                             symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<mesh::MeshDialect>();
  }
};