#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
template <typename SourceAxes, typename TargetAxes>
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
                                     const TargetAxes &targetAxes) {
  return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
    return sourceAxes.contains(targetAxis);
  });
}
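// Compute the sharding that results from additionally splitting tensor axis
// `splitTensorAxis` along grid axis `splitGridAxis`,
// e.g. [[0, 1]] -> [[0, 1, 2]].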
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx,
                                              Sharding sourceSharding,
                                              int64_t splitTensorAxis,
                                              GridAxis splitGridAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         splitTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.push_back(splitGridAxis);
  targetShardingSplitAxes[splitTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
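// Split the last axis: slice the local shard along `splitTensorAxis` across
// grid axis `splitGridAxis` and return the new shard with its sharding.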
static std::tuple<TypedValue<ShapedType>, Sharding>
splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
                          Sharding sourceSharding,
                          TypedValue<ShapedType> sourceShard, GridOp grid,
                          int64_t splitTensorAxis, GridAxis splitGridAxis) {
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      AllSliceOp::create(builder, sourceShard, grid,
                         ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis)
          .getResult());
  Sharding targetSharding = targetShardingInSplitLastAxis(
      builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
  return {targetShard, targetSharding};
}
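// Detect a resharding that appends exactly one grid axis to the split axes of
// some tensor axis, e.g. [[0, 1]] -> [[0, 1, 2]].
// If detected, returns the corresponding (tensor axis, grid axis) pair.
// Insertions in the middle, e.g. [[0, 1]] -> [[0, 2, 1]], are not detected.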
static std::optional<std::tuple<int64_t, GridAxis>>
detectSplitLastAxisInResharding(Sharding sourceSharding,
                                Sharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) {
    if (sourceSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
          targetSharding.getSplitAxes()[tensorAxis].size()) {
        continue;
      }
      if (!llvm::equal(
              sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
              llvm::make_range(
                  targetSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1))) {
        continue;
      }
    } else {
      if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
        continue;
      }
    }
    return std::make_tuple(
        tensorAxis,
        targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                             Sharding sourceSharding, Sharding targetSharding,
                             TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, gridAxis] = detectRes.value();
    return splitLastAxisInResharding(builder, sourceSharding, sourceShard,
                                     grid, tensorAxis, gridAxis);
  }
  return std::nullopt;
}
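// Detect a resharding that removes the last grid axis from the split axes of
// some tensor axis, e.g. [[0, 1, 2]] -> [[0, 1]].
// If detected, returns the corresponding (tensor axis, grid axis) pair.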
static std::optional<std::tuple<int64_t, GridAxis>>
detectUnsplitLastAxisInResharding(Sharding sourceSharding,
                                  Sharding targetSharding) {
  for (size_t tensorAxis = 0;
       tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) {
    if (targetSharding.getSplitAxes().size() > tensorAxis) {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
          targetSharding.getSplitAxes()[tensorAxis].size() + 1)
        continue;
      if (!llvm::equal(
              llvm::make_range(
                  sourceSharding.getSplitAxes()[tensorAxis]
                      .asArrayRef()
                      .begin(),
                  sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
                      1),
              targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
        continue;
    } else {
      if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
        continue;
    }
    return std::make_tuple(
        tensorAxis,
        sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
  }
  return std::nullopt;
}
static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx,
                                                Sharding sourceSharding,
                                                int64_t splitTensorAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
         splitTensorAxis);
  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
  targetSplitAxes.pop_back();
  targetShardingSplitAxes[splitTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);
  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
static ShapedType allGatherResultShapeInUnsplitLastAxis(
    ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[splitTensorAxis] =
      gatherDimension(targetShape[splitTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding(
    ImplicitLocOpBuilder &builder, Sharding sourceSharding,
    ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard,
    GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
  ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
      sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
  Value allGatherResult = AllGatherOp::create(
      builder,
      RankedTensorType::get(allGatherResultShape.getShape(),
                            allGatherResultShape.getElementType()),
      grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard,
      APInt(64, splitTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      tensor::CastOp::create(builder, targetShape, allGatherResult)
          .getResult());
  return {targetShard, targetSharding};
}
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                               Sharding sourceSharding,
                               Sharding targetSharding,
                               ShapedType sourceUnshardedShape,
                               TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
    auto [tensorAxis, gridAxis] = detectRes.value();
    return unsplitLastAxisInResharding(builder, sourceSharding,
                                       sourceUnshardedShape, sourceShard, grid,
                                       tensorAxis, gridAxis);
  }
  return std::nullopt;
}
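// Detect a resharding where the last split grid axis of one tensor axis moves
// to the end of the split axes of another tensor axis.
// If detected, returns (source tensor axis, target tensor axis, grid axis).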
static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
detectMoveLastSplitAxisInResharding(Sharding sourceSharding,
                                    Sharding targetSharding) {
  for (size_t sourceTensorAxis = 0;
       sourceTensorAxis < sourceSharding.getSplitAxes().size();
       ++sourceTensorAxis) {
    for (size_t targetTensorAxis = 0;
         targetTensorAxis < targetSharding.getSplitAxes().size();
         ++targetTensorAxis) {
      if (sourceTensorAxis == targetTensorAxis)
        continue;
      if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
          targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
              targetSharding.getSplitAxes()[targetTensorAxis]
                  .asArrayRef()
                  .back())
        continue;
      if (!llvm::equal(
              llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               sourceSharding.getSplitAxes()[sourceTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1),
              llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
                                   .asArrayRef()
                                   .begin(),
                               targetSharding.getSplitAxes()[targetTensorAxis]
                                       .asArrayRef()
                                       .end() -
                                   1)))
        continue;
      return std::make_tuple(
          sourceTensorAxis, targetTensorAxis,
          sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
    }
  }
  return std::nullopt;
}
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx,
                                             Sharding sourceSharding,
                                             int64_t sourceTensorAxis,
                                             int64_t targetTensorAxis) {
  SmallVector<GridAxesAttr> targetShardingSplitAxes =
      llvm::to_vector(sourceSharding.getSplitAxes());
  while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
         targetTensorAxis) {
    targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {}));
  }

  auto sourceSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
  assert(!sourceSplitAxes.empty());
  auto gridAxis = sourceSplitAxes.back();
  sourceSplitAxes.pop_back();
  targetShardingSplitAxes[sourceTensorAxis] =
      GridAxesAttr::get(ctx, sourceSplitAxes);

  auto targetSplitAxes =
      llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
  targetSplitAxes.push_back(gridAxis);
  targetShardingSplitAxes[targetTensorAxis] =
      GridAxesAttr::get(ctx, targetSplitAxes);

  return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes);
}
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
                                                    int64_t splitCount,
                                                    int64_t sourceTensorAxis,
                                                    int64_t targetTensorAxis) {
  SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
  targetShape[sourceTensorAxis] =
      gatherDimension(targetShape[sourceTensorAxis], splitCount);
  targetShape[targetTensorAxis] =
      shardDimension(targetShape[targetTensorAxis], splitCount);
  return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
}
static std::tuple<TypedValue<ShapedType>, Sharding>
moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                              Sharding sourceSharding,
                              ShapedType sourceUnshardedShape,
                              TypedValue<ShapedType> sourceShard,
                              int64_t sourceTensorAxis,
                              int64_t targetTensorAxis, GridAxis gridAxis) {
  MLIRContext *ctx = builder.getContext();
  builder.setInsertionPointAfterValue(sourceShard);

  Sharding targetSharding = targetShardingInMoveLastAxis(
      ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
  ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
      sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
      targetTensorAxis);
  Value allToAllResult = AllToAllOp::create(
      builder,
      RankedTensorType::get(allToAllResultShape.getShape(),
                            allToAllResultShape.getElementType()),
      grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard,
      APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
  ShapedType targetShape =
      shardShapedType(sourceUnshardedShape, grid, targetSharding);
  TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>(
      tensor::CastOp::create(builder, targetShape, allToAllResult).getResult());
  return {targetShard, targetSharding};
}
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                                 Sharding sourceSharding,
                                 Sharding targetSharding,
                                 ShapedType sourceUnshardedShape,
                                 TypedValue<ShapedType> sourceShard) {
  if (auto detectRes =
          detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
    auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
    return moveLastSplitAxisInResharding(
        builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
        sourceTensorAxis, targetTensorAxis, gridAxis);
  }
  return std::nullopt;
}
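// Detect and handle a resharding where only the halo sizes change: copy the
// "core" (the local data without halos) of the source shard into a buffer
// with the target halo sizes and emit an UpdateHaloOp to fill the halos.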
static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid,
                          Sharding sourceSharding, Sharding targetSharding,
                          ShapedType sourceUnshardedShape,
                          TypedValue<ShapedType> sourceShard) {
  // Currently handles only cases where the halo sizes differ but everything
  // else stays the same (from source to target sharding).
  if (!sourceSharding.equalSplitAxes(targetSharding) ||
      !sourceSharding.getStaticShardedDimsOffsets().empty() ||
      !targetSharding.getStaticShardedDimsOffsets().empty() ||
      sourceSharding.equalHaloSizes(targetSharding)) {
    return std::nullopt;
  }

  auto srcHaloSizes = sourceSharding.getStaticHaloSizes();
  auto tgtHaloSizes = targetSharding.getStaticHaloSizes();
  assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
  assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
          ShapedType::isStaticShape(tgtHaloSizes) &&
          sourceShard.getType().hasStaticShape()) &&
         "dynamic shapes/halos are not supported yet for shard-partition");
  auto rank = sourceShard.getType().getRank();
  auto splitAxes = sourceSharding.getSplitAxes();
  SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0),
      strides(rank, 1), outShape(sourceShard.getType().getShape()),
      coreShape(sourceShard.getType().getShape());

  // Determine the "core" of the source and target shards, i.e. the local
  // tensor data without halos.
  for (auto i = 0u; i < rank; ++i) {
    if (i < splitAxes.size() && !splitAxes[i].empty()) {
      if (!srcHaloSizes.empty()) {
        coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
        srcCoreOffs[i] = srcHaloSizes[i * 2];
      }
      tgtCoreOffs[i] = tgtHaloSizes[i * 2];
      outShape[i] =
          coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
    }
  }

  // Copy the core of the source shard into the core of the target shard and
  // then update the halo regions.
  auto noVals = ValueRange{};
  auto initVal =
      tensor::EmptyOp::create(builder, sourceShard.getLoc(), outShape,
                              sourceShard.getType().getElementType());
  auto core = tensor::ExtractSliceOp::create(
      builder, sourceShard.getLoc(),
      RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
      sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
  auto initOprnd = tensor::InsertSliceOp::create(
      builder, sourceShard.getLoc(), core, initVal, noVals, noVals, noVals,
      tgtCoreOffs, coreShape, strides);

  auto updateHaloResult =
      UpdateHaloOp::create(
          builder, sourceShard.getLoc(),
          RankedTensorType::get(outShape,
                                sourceShard.getType().getElementType()),
          initOprnd, grid.getSymName(),
          GridAxesArrayAttr::get(builder.getContext(),
                                 sourceSharding.getSplitAxes()),
          targetSharding.getDynamicHaloSizes(),
          targetSharding.getStaticHaloSizes())
          .getResult();
  return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult),
                         targetSharding);
}
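// Resharding on a 1D device grid: try moving, splitting or unsplitting the
// last split axis; asserts if no pattern applies.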
static TypedValue<ShapedType>
reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid,
                Sharding sourceSharding, Sharding targetSharding,
                TypedValue<ShapedType> sourceUnshardedValue,
                TypedValue<ShapedType> sourceShard) {
  assert(sourceShard.getType() ==
         shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
  [[maybe_unused]] ShapedType targetShardType =
      shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding);
  assert(sourceShard.getType().getRank() == targetShardType.getRank());
  assert(grid.getRank() == 1 && "Only 1D grids are currently supported.");

  if (sourceSharding == targetSharding) {
    return sourceShard;
  }

  TypedValue<ShapedType> targetShard;
  Sharding actualTargetSharding;
  if (sourceSharding.getStaticShardedDimsOffsets().empty() &&
      targetSharding.getStaticShardedDimsOffsets().empty() &&
      sourceSharding.getStaticHaloSizes().empty() &&
      targetSharding.getStaticHaloSizes().empty()) {
    if (auto tryRes = tryMoveLastSplitAxisInResharding(
            builder, grid, sourceSharding, targetSharding,
            sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = trySplitLastAxisInResharding(
                   builder, grid, sourceSharding, targetSharding,
                   sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    } else if (auto tryRes = tryUnsplitLastAxisInResharding(
                   builder, grid, sourceSharding, targetSharding,
                   sourceUnshardedValue.getType(), sourceShard)) {
      std::tie(targetShard, actualTargetSharding) = tryRes.value();
    }
  }
  assert(targetShard && "Did not find any pattern to apply.");
  assert(actualTargetSharding == targetSharding);
  assert(targetShard.getType() == targetShardType);
  return targetShard;
}
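// General resharding entry point: handle the trivial and halo-only cases
// first, then fall back to resharding on a 1D grid.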
static TypedValue<ShapedType>
reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding,
        Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue,
        TypedValue<ShapedType> sourceShard) {
  // If source and target sharding are the same, there is nothing to do.
  if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) &&
                                           isFullReplication(targetSharding))) {
    return sourceShard;
  }

  // Try to handle the case where only the halo sizes differ. Supports
  // arbitrary grid dimensionality.
  if (auto tryRes = tryUpdateHaloInResharding(
          builder, grid, sourceSharding, targetSharding,
          sourceUnshardedValue.getType(), sourceShard)) {
    return std::get<0>(tryRes.value()); // targetShard
  }

  // Otherwise resort to resharding on a 1D grid.
  return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding,
                         sourceUnshardedValue, sourceShard);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue) {
  assert(source.getResult() == target.getSrc());
  auto sourceSharding = source.getSharding();
  auto targetSharding = target.getSharding();
  ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
  return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
                 cast<TypedValue<ShapedType>>(source.getSrc()),
                 sourceShardValue);
}
TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source,
                               ShardOp target,
                               TypedValue<ShapedType> sourceShardValue,
                               SymbolTableCollection &symbolTableCollection) {
  GridOp srcGrid = getGrid(source, symbolTableCollection);
  assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection));
  return reshard(builder, srcGrid, source, target, sourceShardValue);
}
void reshardingRegisterDependentDialects(DialectRegistry &registry) {
  registry.insert<shard::ShardDialect, tensor::TensorDialect>();
}
#define GEN_PASS_DEF_PARTITION
#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
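// Compute the sharded types of a block's arguments from the ShardOp user of
// each ranked-tensor argument. Non-tensor and 0-d tensor arguments keep their
// original type.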
static SmallVector<Type>
shardedBlockArgumentTypes(Block &block,
                          SymbolTableCollection &symbolTableCollection) {
  SmallVector<Type> res;
  llvm::transform(
      block.getArguments(), std::back_inserter(res),
      [&symbolTableCollection](BlockArgument arg) {
        auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
        if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
          return arg.getType();
        }

        assert(rankedTensorArg.hasOneUse());
        Operation *useOp = *rankedTensorArg.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
        assert(shardOp);
        GridOp grid = getGrid(shardOp, symbolTableCollection);
        return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid,
                                          shardOp.getSharding()));
      });
  return res;
}
static LogicalResult partitionOperation(
    Operation &op, ArrayRef<Value> partitionedOperands,
    ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
    IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection,
    OpBuilder &builder) {
  ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
  if (!shardingInterface) {
    // Without a sharding interface, conservatively treat the op as fully
    // replicated on all devices.
    partitionFullyReplicatedOperation(op, partitionedOperands,
                                      operandShardings, resultShardings,
                                      partitionMap, symbolTableCollection,
                                      builder);
  } else {
    if (failed(shardingInterface.partition(
            partitionedOperands, operandShardings, resultShardings,
            partitionMap, symbolTableCollection, builder))) {
      return failure();
    }
  }

  assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) {
    return partitionMap.contains(result);
  }));

  return success();
}
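// Retrieve the sharding annotation of each operand from its defining ShardOp.
// Operands that are not ranked tensors (or are 0-d tensors) get a default,
// unspecified sharding.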
static std::vector<Sharding> getOperandShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumOperands());
  llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
    TypedValue<RankedTensorType> rankedTensor =
        dyn_cast<TypedValue<RankedTensorType>>(operand);
    if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
      return Sharding();
    }

    Operation *definingOp = operand.getDefiningOp();
    assert(definingOp);
    ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
    return Sharding(shardOp.getSharding());
  });
  return res;
}
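// Retrieve the sharding annotation of each result from its (single) ShardOp
// user. 0-d tensor results without an explicit annotation fall back to a grid
// found among the operands' sharding ops.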
static std::vector<Sharding> getResultShardings(Operation &op) {
  std::vector<Sharding> res;
  res.reserve(op.getNumResults());
  llvm::transform(
      op.getResults(), std::back_inserter(res), [&op](OpResult result) {
        if (!result.hasOneUse() || result.use_empty()) {
          return Sharding();
        }
        TypedValue<RankedTensorType> rankedTensor =
            dyn_cast<TypedValue<RankedTensorType>>(result);
        if (!rankedTensor) {
          return Sharding();
        }
        Operation *userOp = *result.getUsers().begin();
        ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
        if (shardOp) {
          return Sharding(shardOp.getSharding());
        }
        if (rankedTensor.getType().getRank() == 0) {
          // 0-d tensor result without an explicit sharding: pick up a grid
          // symbol from a sharding op among the operands, if any.
          for (auto operand : op.getOperands()) {
            if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
              return Sharding(sharding.getGridAttr());
            }
          }
        }
        return Sharding();
      });
  return res;
}
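// Partition a ShardOp: if it is chained after another ShardOp, insert the
// resharding between the two annotations; otherwise reuse the already
// partitioned source value.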
static LogicalResult
partitionOperation(ShardOp shardOp, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  Value targetPartitionValue;

  // Check if two shard ops are chained. If not, there is no need for
  // resharding, as source and target share the same sharding.
  ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>();
  if (!srcShardOp) {
    targetPartitionValue = partitionMap.lookup(shardOp.getSrc());
  } else {
    // Insert the resharding.
    TypedValue<ShapedType> srcPartitionValue =
        cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp));
    targetPartitionValue = reshard(builder, srcShardOp, shardOp,
                                   srcPartitionValue, symbolTableCollection);
  }

  assert(!partitionMap.contains(shardOp.getResult()));
  partitionMap.map(shardOp.getResult(), targetPartitionValue);
  return success();
}
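// Partition a generic operation: sharding and get-sharding annotations are
// handled specially; everything else is dispatched to the
// ShardingInterface-based partitioning with its operand and result shardings.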
static LogicalResult
partitionOperation(Operation &op, IRMapping &partitionMap,
                   SymbolTableCollection &symbolTableCollection,
                   OpBuilder &builder) {
  if (isa<ShardingOp>(op)) {
    return success();
  }
  if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
    auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
    if (!shardOp) {
      return op.emitError("expected a shard op as source of get_sharding");
    }
    auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
    partitionMap.map(op.getResult(0), newSharding->getResult(0));
    return success();
  }

  ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
  if (shardOp) {
    return partitionOperation(shardOp, partitionMap, symbolTableCollection,
                              builder);
  }

  SmallVector<Value> partitionedOperands;
  llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands),
                  [&partitionMap](Value operand) {
                    assert(partitionMap.contains(operand));
                    return partitionMap.lookup(operand);
                  });
  return partitionOperation(op, partitionedOperands, getOperandShardings(op),
                            getResultShardings(op), partitionMap,
                            symbolTableCollection, builder);
}
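// Partition a block: create a new block with sharded argument types and
// partition every operation into it.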
static LogicalResult
partitionBlock(Block &block, IRMapping &partitionMap,
               SymbolTableCollection &symbolTableCollection,
               OpBuilder &builder) {
  SmallVector<Location> argLocations;
  llvm::transform(block.getArguments(), std::back_inserter(argLocations),
                  [](BlockArgument arg) { return arg.getLoc(); });
  Block *newBlock = builder.createBlock(
      block.getParent(), {},
      shardedBlockArgumentTypes(block, symbolTableCollection), argLocations);
  for (auto [unshardedBlockArg, partitionedBlockArg] :
       llvm::zip(block.getArguments(), newBlock->getArguments())) {
    partitionMap.map(unshardedBlockArg, partitionedBlockArg);
  }

  OpBuilder::InsertionGuard insertionGuard(builder);
  builder.setInsertionPointToEnd(newBlock);
  for (Operation &op : block.getOperations()) {
    if (failed(partitionOperation(op, partitionMap, symbolTableCollection,
                                  builder))) {
      return failure();
    }
  }
  return success();
}
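// Partition a function: partition all blocks containing shard ops, erase the
// original blocks and update the function signature from the new return op.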
static LogicalResult
partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap,
                SymbolTableCollection &symbolTableCollection) {
  OpBuilder builder(op.getFunctionBody());

  // Snapshot the original blocks so that adding new blocks does not disturb
  // the iteration.
  SmallVector<Block *> originalBlocks;
  for (Block &b : op.getBlocks()) {
    if (llvm::any_of(b.getOperations(),
                     [](Operation &op) { return isa<ShardOp>(op); })) {
      originalBlocks.push_back(&b);
    }
  }

  for (Block *block : originalBlocks) {
    if (failed(partitionBlock(*block, partitionMap, symbolTableCollection,
                              builder))) {
      return failure();
    }
  }

  for (Block *block : originalBlocks) {
    block->erase();
  }

  // Find a return op and update the function result types to match its
  // operand types.
  Operation *returnOp = nullptr;
  for (Block &block : op.getFunctionBody()) {
    if (block.empty()) {
      continue;
    }
    if (block.back().hasTrait<OpTrait::ReturnLike>()) {
      returnOp = &block.back();
      break;
    }
  }
  if (returnOp) {
    op.setType(FunctionType::get(
        op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
        returnOp->getOperandTypes()));
  }
  return success();
}
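// The partition pass driver: partition the function and register the dialects
// required by the generated resharding operations.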
struct Partition : public impl::PartitionBase<Partition> {
  void runOnOperation() override {
    IRMapping partitionMap;
    SymbolTableCollection symbolTableCollection;
    if (failed(partitionFuncOp(getOperation(), partitionMap,
                               symbolTableCollection))) {
      return signalPassFailure();
    }
  }

  void getDependentDialects(DialectRegistry &registry) const override {
    reshardingRegisterDependentDialects(registry);
    registry.insert<shard::ShardDialect>();
  }
};