31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Casting.h"
39 #include <type_traits>
43 template <
typename SourceAxes,
typename TargetAxes>
45 const TargetAxes &targetAxes) {
46 return llvm::all_of(targetAxes, [&sourceAxes](
auto &targetAxis) {
47 return sourceAxes.contains(targetAxis);
62 if (sourceSharding.getPartialAxes().empty() &&
63 targetSharding.getPartialAxes().empty()) {
64 return {sourceShard, sourceSharding};
66 assert(targetSharding.getPartialAxes().empty() ||
67 (!sourceSharding.getPartialAxes().empty() &&
68 sourceSharding.getPartialType() == targetSharding.getPartialType()));
69 using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
70 using AxisSet = llvm::SmallDenseSet<Axis>;
71 AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
72 sourceSharding.getPartialAxes().end());
73 AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
74 targetSharding.getPartialAxes().end());
76 targetShardingPartialAxesSet));
78 llvm::copy_if(sourceShardingPartialAxesSet,
79 std::back_inserter(allReduceMeshAxes),
80 [&targetShardingPartialAxesSet](Axis a) {
81 return !targetShardingPartialAxesSet.contains(a);
83 if (allReduceMeshAxes.empty()) {
84 return {sourceShard, sourceSharding};
90 .
create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
91 sourceSharding.getMesh().getLeafReference(),
92 allReduceMeshAxes, sourceShard,
93 sourceSharding.getPartialType())
97 llvm::copy_if(sourceShardingPartialAxesSet,
98 std::back_inserter(allReduceMeshAxes),
99 [&targetShardingPartialAxesSet](Axis a) {
100 return targetShardingPartialAxesSet.contains(a);
104 sourceSharding.getSplitAxes(), remainingPartialAxes,
105 sourceSharding.getPartialType());
106 return {resultValue, resultSharding};
111 int64_t splitTensorAxis,
MeshAxis splitMeshAxis) {
113 llvm::to_vector(sourceSharding.getSplitAxes());
114 while (
static_cast<int64_t
>(targetShardingSplitAxes.size()) <=
118 auto targetSplitAxes =
119 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
120 targetSplitAxes.push_back(splitMeshAxis);
121 targetShardingSplitAxes[splitTensorAxis] =
124 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
125 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
135 int64_t splitTensorAxis,
MeshAxis splitMeshAxis) {
138 .
create<AllSliceOp>(sourceShard, mesh,
143 builder.
getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144 return {targetShard, targetSharding};
152 static std::optional<std::tuple<int64_t, MeshAxis>>
155 for (
size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
157 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
158 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
159 targetSharding.getSplitAxes()[tensorAxis].size()) {
163 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
165 targetSharding.getSplitAxes()[tensorAxis]
168 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
173 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
177 return std::make_tuple(
179 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
191 auto [tensorAxis, meshAxis] = detectRes.value();
193 tensorAxis, meshAxis);
202 static std::optional<std::tuple<int64_t, MeshAxis>>
205 for (
size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
207 if (targetSharding.getSplitAxes().size() > tensorAxis) {
208 if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
209 targetSharding.getSplitAxes()[tensorAxis].size() + 1)
213 sourceSharding.getSplitAxes()[tensorAxis]
216 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
218 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
221 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
224 return std::make_tuple(
226 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
234 int64_t splitTensorAxis) {
236 llvm::to_vector(sourceSharding.getSplitAxes());
237 assert(
static_cast<int64_t
>(targetShardingSplitAxes.size()) >
239 auto targetSplitAxes =
240 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
242 targetSplitAxes.pop_back();
243 targetShardingSplitAxes[splitTensorAxis] =
246 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
247 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
251 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
253 targetShape[splitTensorAxis] =
255 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
261 ShapedType sourceUnshardedShape,
263 int64_t splitTensorAxis,
MeshAxis splitMeshAxis) {
270 sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
271 Value allGatherResult = builder.
create<AllGatherOp>(
273 allGatherResultShape.getElementType()),
275 APInt(64, splitTensorAxis));
276 ShapedType targetShape =
279 builder.
create<tensor::CastOp>(targetShape, allGatherResult).getResult());
280 return {targetShard, targetSharding};
287 ShapedType sourceUnshardedShape,
291 auto [tensorAxis, meshAxis] = detectRes.value();
293 sourceUnshardedShape, sourceShard, mesh,
294 tensorAxis, meshAxis);
305 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
308 for (
size_t sourceTensorAxis = 0;
309 sourceTensorAxis < sourceSharding.getSplitAxes().size();
310 ++sourceTensorAxis) {
311 for (
size_t targetTensorAxis = 0;
312 targetTensorAxis < targetSharding.getSplitAxes().size();
313 ++targetTensorAxis) {
314 if (sourceTensorAxis == targetTensorAxis)
316 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
317 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
318 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
319 targetSharding.getSplitAxes()[targetTensorAxis]
324 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
327 sourceSharding.getSplitAxes()[sourceTensorAxis]
331 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
334 targetSharding.getSplitAxes()[targetTensorAxis]
339 return std::make_tuple(
340 sourceTensorAxis, targetTensorAxis,
341 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
349 int64_t sourceTensorAxis,
350 int64_t targetTensorAxis) {
352 llvm::to_vector(sourceSharding.getSplitAxes());
353 while (
static_cast<int64_t
>(targetShardingSplitAxes.size()) <=
358 auto sourceSplitAxes =
359 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
360 assert(!sourceSplitAxes.empty());
361 auto meshAxis = sourceSplitAxes.back();
362 sourceSplitAxes.pop_back();
363 targetShardingSplitAxes[sourceTensorAxis] =
366 auto targetSplitAxes =
367 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
368 targetSplitAxes.push_back(meshAxis);
369 targetShardingSplitAxes[targetTensorAxis] =
373 ctx, sourceSharding.getMesh(), targetShardingSplitAxes,
374 sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
379 int64_t sourceTensorAxis,
380 int64_t targetTensorAxis) {
382 targetShape[sourceTensorAxis] =
384 targetShape[targetTensorAxis] =
386 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
392 ShapedType sourceUnshardedShape,
394 int64_t sourceTensorAxis,
395 int64_t targetTensorAxis,
MeshAxis meshAxis) {
400 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
402 sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
406 allToAllResultShape.getElementType()),
408 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
409 ShapedType targetShape =
412 builder.
create<tensor::CastOp>(targetShape, allToAllResult).getResult());
413 return {targetShard, targetSharding};
420 ShapedType sourceUnshardedShape,
424 auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
426 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
427 sourceTensorAxis, targetTensorAxis, meshAxis);
442 assert(sourceShard.getType() ==
443 shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
444 [[maybe_unused]] ShapedType targetShardType =
446 assert(sourceShard.getType().getRank() == targetShardType.getRank());
447 assert(mesh.getRank() == 1 &&
"Only 1D meshes are currently supported.");
449 auto [reducedSourceShard, reducedSourceSharding] =
453 if (reducedSourceSharding == targetSharding) {
454 return reducedSourceShard;
460 builder, mesh, reducedSourceSharding, targetSharding,
461 sourceUnshardedValue.
getType(), reducedSourceShard)) {
462 std::tie(targetShard, actualTargetSharding) = tryRes.value();
464 builder, mesh, reducedSourceSharding, targetSharding,
465 reducedSourceShard)) {
466 std::tie(targetShard, actualTargetSharding) = tryRes.value();
468 builder, mesh, reducedSourceSharding, targetSharding,
469 sourceUnshardedValue.
getType(), reducedSourceShard)) {
470 std::tie(targetShard, actualTargetSharding) = tryRes.value();
472 assert(
false &&
"Did not find any pattern to apply.");
475 assert(actualTargetSharding == targetSharding);
476 assert(targetShard.getType() == targetShardType);
489 sourceUnshardedValue, sourceShard);
495 assert(source.getResult() == target.getOperand());
498 implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
507 assert(srcMesh && srcMesh ==
getMesh(target, symbolTableCollection));
508 return reshard(builder, srcMesh, source, target, sourceShardValue);
512 registry.
insert<mesh::MeshDialect, tensor::TensorDialect>();
515 #define GEN_PASS_DEF_SPMDIZATION
516 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
530 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
531 if (!rankedTensorArg) {
532 return arg.getType();
535 assert(rankedTensorArg.hasOneUse());
537 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
541 shardOp.getShardAttr()));
551 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
552 if (!shardingInterface) {
556 resultShardings, spmdizationMap,
557 symbolTableCollection, builder);
559 if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
560 resultShardings, spmdizationMap,
561 symbolTableCollection, builder))) {
566 assert(llvm::all_of(op.getResults(), [&spmdizationMap](
OpResult result) {
567 return spmdizationMap.contains(result);
577 res.reserve(op.getNumOperands());
578 llvm::transform(op.getOperands(), std::back_inserter(res), [](
Value operand) {
579 TypedValue<RankedTensorType> rankedTensor =
580 dyn_cast<TypedValue<RankedTensorType>>(operand);
582 return MeshShardingAttr();
587 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
588 return shardOp.getShard();
597 res.reserve(op.getNumResults());
598 llvm::transform(op.getResults(), std::back_inserter(res),
600 TypedValue<RankedTensorType> rankedTensor =
601 dyn_cast<TypedValue<RankedTensorType>>(result);
603 return MeshShardingAttr();
608 ShardOp shardOp = llvm::cast<ShardOp>(userOp);
609 return shardOp.getShard();
618 Value targetSpmdValue;
623 dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp());
625 targetSpmdValue = spmdizationMap.
lookup(shardOp.getOperand());
629 spmdizationMap.
lookup(srcShardOp.getOperand()));
630 targetSpmdValue =
reshard(builder, srcShardOp, shardOp, srcSpmdValue,
631 symbolTableCollection);
634 assert(!spmdizationMap.
contains(shardOp.getResult()));
635 spmdizationMap.
map(shardOp.getResult(), targetSpmdValue);
643 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
650 llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
651 [&spmdizationMap](
Value operand) {
652 assert(spmdizationMap.contains(operand));
653 return spmdizationMap.lookup(operand);
657 symbolTableCollection, builder);
664 llvm::transform(block.
getArguments(), std::back_inserter(argLocations),
669 for (
auto [unshardedBlockArg, spmdizedBlockArg] :
671 spmdizationMap.
map(unshardedBlockArg, spmdizedBlockArg);
694 llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks),
695 [](
Block &b) { return &b; });
697 for (
Block *block : originalBlocks) {
698 if (failed(
spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
704 for (
Block *block : originalBlocks) {
711 for (
Block &block : op.getFunctionBody()) {
717 returnOp = &block.back();
723 op.getFunctionBody().front().getArgumentTypes(),
731 struct Spmdization :
public impl::SpmdizationBase<Spmdization> {
732 void runOnOperation()
override {
736 symbolTableCollection))) {
737 return signalPassFailure();
741 void getDependentDialects(DialectRegistry ®istry)
const override {
743 registry.insert<mesh::MeshDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location parameter.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
user_range getUsers()
Returns a range of all users.
This class represents a collection of SymbolTables.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
mesh::MeshShardingAttr MeshShardingAttr
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > handlePartialAxesDuringResharding(OpBuilder &builder, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
static SmallVector< MeshShardingAttr > getOperandShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
static MeshShardingAttr targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis)
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static std::tuple< TypedValue< ShapedType >, MeshShardingAttr > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshShardingAttr sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static SmallVector< MeshShardingAttr > getResultShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshShardingAttr > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceShard)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshShardingAttr sourceSharding, MeshShardingAttr targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static MeshShardingAttr targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static MeshShardingAttr targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
This trait indicates that a terminator operation is "return-like".