32 #include "llvm/ADT/APInt.h"
33 #include "llvm/ADT/DenseSet.h"
34 #include "llvm/ADT/STLExtras.h"
35 #include "llvm/ADT/SmallVector.h"
36 #include "llvm/Support/Casting.h"
40 #include <type_traits>
44 template <
typename SourceAxes,
typename TargetAxes>
46 const TargetAxes &targetAxes) {
47 return llvm::all_of(targetAxes, [&sourceAxes](
auto &targetAxis) {
48 return sourceAxes.contains(targetAxis);
63 if (sourceSharding.getPartialAxes().empty() &&
64 targetSharding.getPartialAxes().empty()) {
65 return {sourceShard, sourceSharding};
67 assert(targetSharding.getPartialAxes().empty() ||
68 (!sourceSharding.getPartialAxes().empty() &&
69 sourceSharding.getPartialType() == targetSharding.getPartialType()));
70 using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
71 using AxisSet = llvm::SmallDenseSet<Axis>;
72 AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
73 sourceSharding.getPartialAxes().end());
74 AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
75 targetSharding.getPartialAxes().end());
77 targetShardingPartialAxesSet));
79 llvm::copy_if(sourceShardingPartialAxesSet,
80 std::back_inserter(allReduceMeshAxes),
81 [&targetShardingPartialAxesSet](Axis a) {
82 return !targetShardingPartialAxesSet.contains(a);
84 if (allReduceMeshAxes.empty()) {
85 return {sourceShard, sourceSharding};
91 .
create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
92 sourceSharding.getMesh().getLeafReference(),
93 allReduceMeshAxes, sourceShard,
94 sourceSharding.getPartialType())
98 llvm::copy_if(sourceShardingPartialAxesSet,
99 std::back_inserter(allReduceMeshAxes),
100 [&targetShardingPartialAxesSet](Axis a) {
101 return targetShardingPartialAxesSet.contains(a);
105 sourceSharding.getSplitAxes(), remainingPartialAxes,
106 sourceSharding.getPartialType());
107 return {resultValue, resultSharding};
112 int64_t splitTensorAxis,
MeshAxis splitMeshAxis) {
114 llvm::to_vector(sourceSharding.getSplitAxes());
115 while (
static_cast<int64_t
>(targetShardingSplitAxes.size()) <=
119 auto targetSplitAxes =
120 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
121 targetSplitAxes.push_back(splitMeshAxis);
122 targetShardingSplitAxes[splitTensorAxis] =
125 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
126 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
136 int64_t splitTensorAxis,
MeshAxis splitMeshAxis) {
139 .
create<AllSliceOp>(sourceShard, mesh,
144 builder.
getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
145 return {targetShard, targetSharding};
153 static std::optional<std::tuple<int64_t, MeshAxis>>
156 for (
size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
158 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
159 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
160 targetSharding.getSplitAxes()[tensorAxis].size()) {
164 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
166 targetSharding.getSplitAxes()[tensorAxis]
169 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
174 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
178 return std::make_tuple(
180 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
192 auto [tensorAxis, meshAxis] = detectRes.value();
194 tensorAxis, meshAxis);
203 static std::optional<std::tuple<int64_t, MeshAxis>>
206 for (
size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
208 if (targetSharding.getSplitAxes().size() > tensorAxis) {
209 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
210 targetSharding.getSplitAxes()[tensorAxis].size() + 1)
214 sourceSharding.getSplitAxes()[tensorAxis]
217 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
219 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
222 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
225 return std::make_tuple(
227 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
235 int64_t splitTensorAxis) {
237 llvm::to_vector(sourceSharding.getSplitAxes());
238 assert(
static_cast<int64_t
>(targetShardingSplitAxes.size()) >
240 auto targetSplitAxes =
241 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
243 targetSplitAxes.pop_back();
244 targetShardingSplitAxes[splitTensorAxis] =
247 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
248 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
252 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
254 targetShape[splitTensorAxis] =
256 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
262 ShapedType sourceUnshardedShape,
264 int64_t splitTensorAxis,
MeshAxis splitMeshAxis) {
271 sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
272 Value allGatherResult = builder.
create<AllGatherOp>(
274 allGatherResultShape.getElementType()),
276 APInt(64, splitTensorAxis));
277 ShapedType targetShape =
280 builder.
create<tensor::CastOp>(targetShape, allGatherResult).getResult());
281 return {targetShard, targetSharding};
288 ShapedType sourceUnshardedShape,
292 auto [tensorAxis, meshAxis] = detectRes.value();
294 sourceUnshardedShape, sourceShard, mesh,
295 tensorAxis, meshAxis);
306 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
309 for (
size_t sourceTensorAxis = 0;
310 sourceTensorAxis < sourceSharding.getSplitAxes().size();
311 ++sourceTensorAxis) {
312 for (
size_t targetTensorAxis = 0;
313 targetTensorAxis < targetSharding.getSplitAxes().size();
314 ++targetTensorAxis) {
315 if (sourceTensorAxis == targetTensorAxis)
317 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
318 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
319 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
320 targetSharding.getSplitAxes()[targetTensorAxis]
325 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
328 sourceSharding.getSplitAxes()[sourceTensorAxis]
332 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
335 targetSharding.getSplitAxes()[targetTensorAxis]
340 return std::make_tuple(
341 sourceTensorAxis, targetTensorAxis,
342 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
350 int64_t sourceTensorAxis,
351 int64_t targetTensorAxis) {
353 llvm::to_vector(sourceSharding.getSplitAxes());
354 while (
static_cast<int64_t
>(targetShardingSplitAxes.size()) <=
359 auto sourceSplitAxes =
360 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
361 assert(!sourceSplitAxes.empty());
362 auto meshAxis = sourceSplitAxes.back();
363 sourceSplitAxes.pop_back();
364 targetShardingSplitAxes[sourceTensorAxis] =
367 auto targetSplitAxes =
368 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
369 targetSplitAxes.push_back(meshAxis);
370 targetShardingSplitAxes[targetTensorAxis] =
374 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
375 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
380 int64_t sourceTensorAxis,
381 int64_t targetTensorAxis) {
383 targetShape[sourceTensorAxis] =
385 targetShape[targetTensorAxis] =
387 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
393 ShapedType sourceUnshardedShape,
395 int64_t sourceTensorAxis,
396 int64_t targetTensorAxis,
MeshAxis meshAxis) {
401 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
403 sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
407 allToAllResultShape.getElementType()),
409 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
410 ShapedType targetShape =
413 builder.
create<tensor::CastOp>(targetShape, allToAllResult).getResult());
414 return {targetShard, targetSharding};
421 ShapedType sourceUnshardedShape,
425 auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
427 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
428 sourceTensorAxis, targetTensorAxis, meshAxis);
443 assert(sourceShard.getType() ==
444 shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
445 [[maybe_unused]] ShapedType targetShardType =
447 assert(sourceShard.getType().getRank() == targetShardType.getRank());
448 assert(mesh.getRank() == 1 &&
"Only 1D meshes are currently supported.");
450 auto [reducedSourceShard, reducedSourceSharding] =
454 if (reducedSourceSharding == targetSharding) {
455 return reducedSourceShard;
461 builder, mesh, reducedSourceSharding, targetSharding,
462 sourceUnshardedValue.
getType(), reducedSourceShard)) {
463 std::tie(targetShard, actualTargetSharding) = tryRes.value();
465 builder, mesh, reducedSourceSharding, targetSharding,
466 reducedSourceShard)) {
467 std::tie(targetShard, actualTargetSharding) = tryRes.value();
469 builder, mesh, reducedSourceSharding, targetSharding,
470 sourceUnshardedValue.
getType(), reducedSourceShard)) {
471 std::tie(targetShard, actualTargetSharding) = tryRes.value();
473 assert(
false &&
"Did not find any pattern to apply.");
476 assert(actualTargetSharding == targetSharding);
477 assert(targetShard.getType() == targetShardType);
490 sourceUnshardedValue, sourceShard);
496 assert(!source.getAnnotateForUsers());
497 assert(target.getAnnotateForUsers());
498 assert(source.getResult() == target.getOperand());
501 implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
510 assert(srcMesh && srcMesh ==
getMesh(target, symbolTableCollection));
511 return reshard(builder, srcMesh, source, target, sourceShardValue);
515 registry.
insert<mesh::MeshDialect, tensor::TensorDialect>();
518 #define GEN_PASS_DEF_SPMDIZATION
519 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
533 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
534 if (!rankedTensorArg) {
535 return arg.getType();
538 assert(rankedTensorArg.hasOneUse());
540 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
544 shardOp.getShardAttr()));
554 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
555 if (!shardingInterface) {
559 resultShardings, spmdizationMap,
560 symbolTableCollection, builder);
562 if (
failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
563 resultShardings, spmdizationMap,
564 symbolTableCollection, builder))) {
569 assert(llvm::all_of(op.getResults(), [&spmdizationMap](
OpResult result) {
570 return spmdizationMap.contains(result);
580 res.reserve(op.getNumOperands());
581 llvm::transform(op.getOperands(), std::back_inserter(res), [](
Value operand) {
582 TypedValue<RankedTensorType> rankedTensor =
583 dyn_cast<TypedValue<RankedTensorType>>(operand);
585 return MeshShardingAttr();
590 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
591 return shardOp.getShard();
600 res.reserve(op.getNumResults());
601 llvm::transform(op.getResults(), std::back_inserter(res),
603 TypedValue<RankedTensorType> rankedTensor =
604 dyn_cast<TypedValue<RankedTensorType>>(result);
606 return MeshShardingAttr();
611 ShardOp shardOp = llvm::cast<ShardOp>(userOp);
612 return shardOp.getShard();
621 Value targetSpmdValue;
626 dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
628 targetSpmdValue = spmdizationMap.
lookup(shardOp.getOperand());
631 assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers());
633 spmdizationMap.
lookup(srcShardOp.getOperand()));
634 targetSpmdValue =
reshard(builder, srcShardOp, shardOp, srcSpmdValue,
635 symbolTableCollection);
638 assert(!spmdizationMap.
contains(shardOp.getResult()));
639 spmdizationMap.
map(shardOp.getResult(), targetSpmdValue);
647 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
654 llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
655 [&spmdizationMap](
Value operand) {
656 assert(spmdizationMap.contains(operand));
657 return spmdizationMap.lookup(operand);
661 symbolTableCollection, builder);
668 llvm::transform(block.
getArguments(), std::back_inserter(argLocations),
673 for (
auto [unshardedBlockArg, spmdizedBlockArg] :
675 spmdizationMap.
map(unshardedBlockArg, spmdizedBlockArg);
698 llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
699 [](
Block &b) { return &b; });
701 for (
Block *block : originalBlocks) {
708 for (
Block *block : originalBlocks) {
715 for (
Block &block : op.getFunctionBody()) {
721 returnOp = &block.back();
727 op.getFunctionBody().front().getArgumentTypes(),
735 struct Spmdization :
public impl::SpmdizationBase<Spmdization> {
736 void runOnOperation()
override {
740 symbolTableCollection))) {
741 return signalPassFailure();
745 void getDependentDialects(DialectRegistry ®istry)
const override {
747 registry.insert<mesh::MeshDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
user_range getUsers()
Returns a range of all users.
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
mesh::MeshShardingAttr MeshShardingAttr
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > handlePartialAxesDuringResharding(OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
static SmallVector< MeshShardingAttr > getOperandShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
static MeshShardingAttr targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis)
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static SmallVector< MeshShardingAttr > getResultShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
void reshardingRegisterDependentDialects(DialectRegistry ®istry)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static MeshShardingAttr targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static MeshShardingAttr targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This trait indicates that a terminator operation is "return-like".