29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
38template <
typename SourceAxes,
typename TargetAxes>
40 const TargetAxes &targetAxes) {
41 return llvm::all_of(targetAxes, [&sourceAxes](
auto &targetAxis) {
42 return sourceAxes.contains(targetAxis);
52 while (
static_cast<int64_t>(targetShardingSplitAxes.size()) <=
56 auto targetSplitAxes =
57 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
58 targetSplitAxes.push_back(splitGridAxis);
59 targetShardingSplitAxes[splitTensorAxis] =
67static std::tuple<TypedValue<ShapedType>,
Sharding>
73 AllSliceOp::create(builder, sourceShard, grid,
77 builder.
getContext(), sourceSharding, splitTensorAxis, splitGridAxis);
78 return {targetShard, targetSharding};
86static std::optional<std::tuple<int64_t, GridAxis>>
89 for (
size_t tensorAxis = 0; tensorAxis < targetSharding.
getSplitAxes().size();
92 if (sourceSharding.
getSplitAxes()[tensorAxis].size() + 1 !=
102 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef().end() -
107 if (targetSharding.
getSplitAxes()[tensorAxis].size() != 1) {
111 return std::make_tuple(
113 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef().back());
118static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
124 auto [tensorAxis, gridAxis] = detectRes.value();
126 tensorAxis, gridAxis);
135static std::optional<std::tuple<int64_t, GridAxis>>
138 for (
size_t tensorAxis = 0; tensorAxis < sourceSharding.
getSplitAxes().size();
140 if (targetSharding.
getSplitAxes().size() > tensorAxis) {
149 sourceSharding.
getSplitAxes()[tensorAxis].asArrayRef().end() -
151 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef()))
154 if (sourceSharding.
getSplitAxes()[tensorAxis].size() != 1)
157 return std::make_tuple(
159 sourceSharding.
getSplitAxes()[tensorAxis].asArrayRef().back());
169 assert(
static_cast<int64_t>(targetShardingSplitAxes.size()) >
171 auto targetSplitAxes =
172 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
174 targetSplitAxes.pop_back();
175 targetShardingSplitAxes[splitTensorAxis] =
181 ShapedType sourceShape,
int64_t splitCount,
int64_t splitTensorAxis) {
183 targetShape[splitTensorAxis] =
185 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
198 sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
199 Value allGatherResult = AllGatherOp::create(
201 RankedTensorType::get(allGatherResultShape.getShape(),
202 allGatherResultShape.getElementType()),
204 APInt(64, splitTensorAxis));
205 ShapedType targetShape =
208 tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
209 return {targetShard, targetSharding};
212static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
215 ShapedType sourceUnshardedShape,
219 auto [tensorAxis, gridAxis] = detectRes.value();
221 sourceUnshardedShape, sourceShard, grid,
222 tensorAxis, gridAxis);
233static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
236 for (
size_t sourceTensorAxis = 0;
237 sourceTensorAxis < sourceSharding.
getSplitAxes().size();
238 ++sourceTensorAxis) {
239 for (
size_t targetTensorAxis = 0;
240 targetTensorAxis < targetSharding.
getSplitAxes().size();
241 ++targetTensorAxis) {
242 if (sourceTensorAxis == targetTensorAxis)
244 if (sourceSharding.
getSplitAxes()[sourceTensorAxis].empty() ||
245 targetSharding.
getSplitAxes()[targetTensorAxis].empty() ||
246 sourceSharding.
getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
252 llvm::make_range(sourceSharding.
getSplitAxes()[sourceTensorAxis]
259 llvm::make_range(targetSharding.
getSplitAxes()[targetTensorAxis]
267 return std::make_tuple(
268 sourceTensorAxis, targetTensorAxis,
269 sourceSharding.
getSplitAxes()[sourceTensorAxis].asArrayRef().back());
281 while (
static_cast<int64_t>(targetShardingSplitAxes.size()) <=
286 auto sourceSplitAxes =
287 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
288 assert(!sourceSplitAxes.empty());
289 auto gridAxis = sourceSplitAxes.back();
290 sourceSplitAxes.pop_back();
291 targetShardingSplitAxes[sourceTensorAxis] =
294 auto targetSplitAxes =
295 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
296 targetSplitAxes.push_back(gridAxis);
297 targetShardingSplitAxes[targetTensorAxis] =
308 targetShape[sourceTensorAxis] =
310 targetShape[targetTensorAxis] =
312 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
315static std::tuple<TypedValue<ShapedType>, Sharding>
318 ShapedType sourceUnshardedShape,
326 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
328 sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
330 Value allToAllResult = AllToAllOp::create(
332 RankedTensorType::get(allToAllResultShape.getShape(),
333 allToAllResultShape.getElementType()),
335 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
336 ShapedType targetShape =
339 tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
340 return {targetShard, targetSharding};
343static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
347 ShapedType sourceUnshardedShape,
351 auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
353 builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
354 sourceTensorAxis, targetTensorAxis, gridAxis);
364static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
367 ShapedType sourceUnshardedShape,
380 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
381 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
382 ShapedType::isStaticShape(tgtHaloSizes) &&
383 sourceShard.getType().hasStaticShape()) &&
384 "dynamic shapes/halos are not supported yet for shard-partition");
385 auto rank = sourceShard.getType().getRank();
388 strides(rank, 1), outShape(sourceShard.getType().getShape()),
389 coreShape(sourceShard.getType().getShape());
393 for (
auto i = 0u; i < rank; ++i) {
394 if (i < splitAxes.size() && !splitAxes[i].empty()) {
395 if (!srcHaloSizes.empty()) {
396 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
397 srcCoreOffs[i] = srcHaloSizes[i * 2];
399 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
401 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
408 tensor::EmptyOp::create(builder, sourceShard.
getLoc(), outShape,
409 sourceShard.getType().getElementType());
410 auto core = tensor::ExtractSliceOp::create(
411 builder, sourceShard.
getLoc(),
412 RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
413 sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
414 auto initOprnd = tensor::InsertSliceOp::create(
415 builder, sourceShard.
getLoc(), core, initVal, noVals, noVals, noVals,
416 tgtCoreOffs, coreShape, strides);
419 auto updateHaloResult =
420 UpdateHaloOp::create(
421 builder, sourceShard.
getLoc(),
422 RankedTensorType::get(outShape,
423 sourceShard.getType().getElementType()),
424 initOprnd, grid.getSymName(),
442 assert(sourceShard.getType() ==
443 shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
444 [[maybe_unused]] ShapedType targetShardType =
446 assert(sourceShard.getType().getRank() == targetShardType.getRank());
447 assert(grid.getRank() == 1 &&
"Only 1D grides are currently supported.");
449 if (sourceSharding == targetSharding) {
460 builder, grid, sourceSharding, targetSharding,
461 sourceUnshardedValue.getType(), sourceShard)) {
462 std::tie(targetShard, actualTargetSharding) = tryRes.value();
463 }
else if (
auto tryRes =
465 targetSharding, sourceShard)) {
466 std::tie(targetShard, actualTargetSharding) = tryRes.value();
468 builder, grid, sourceSharding, targetSharding,
469 sourceUnshardedValue.getType(), sourceShard)) {
470 std::tie(targetShard, actualTargetSharding) = tryRes.value();
473 assert(targetShard &&
"Did not find any pattern to apply.");
474 assert(actualTargetSharding == targetSharding);
475 assert(targetShard.getType() == targetShardType);
492 builder, grid, sourceSharding, targetSharding,
493 sourceUnshardedValue.getType(), sourceShard)) {
494 return std::get<0>(tryRes.value());
501 sourceUnshardedValue, sourceShard);
507 assert(source.getResult() ==
target.getSrc());
508 auto sourceSharding = source.getSharding();
509 auto targetSharding =
target.getSharding();
511 return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
512 source.getSrc(), sourceShardValue);
519 GridOp srcGrid =
getGrid(source, symbolTableCollection);
520 assert(srcGrid && srcGrid ==
getGrid(
target, symbolTableCollection));
521 return reshard(builder, srcGrid, source,
target, sourceShardValue);
525 registry.
insert<shard::ShardDialect, tensor::TensorDialect>();
528#define GEN_PASS_DEF_PARTITION
529#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
543 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
544 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
545 return arg.getType();
548 assert(rankedTensorArg.hasOneUse());
550 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
552 GridOp grid =
getGrid(shardOp, symbolTableCollection);
554 shardOp.getSharding()));
565 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
566 if (!shardingInterface) {
570 resultShardings, partitionMap,
571 symbolTableCollection, builder);
573 if (failed(shardingInterface.partition(
574 partitionedOperands, operandShardings, resultShardings,
575 partitionMap, symbolTableCollection, builder))) {
581 return partitionMap.contains(result);
590 std::vector<Sharding> res;
592 llvm::transform(op.
getOperands(), std::back_inserter(res), [](
Value operand) {
593 TypedValue<RankedTensorType> rankedTensor =
594 dyn_cast<TypedValue<RankedTensorType>>(operand);
595 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
601 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
602 return Sharding(shardOp.getSharding());
610 std::vector<Sharding> res;
614 if (!result.hasOneUse() || result.use_empty()) {
623 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
625 return Sharding(shardOp.getSharding());
627 if (rankedTensor.getType().getRank() == 0) {
632 if (
auto sharding = operand.getDefiningOp<ShardingOp>()) {
633 return Sharding(sharding.getGridAttr());
646 Value targetPartitionValue;
650 ShardOp srcShardOp = shardOp.getSrc().
getDefiningOp<ShardOp>();
652 targetPartitionValue = partitionMap.
lookup(shardOp.getSrc());
656 cast<TypedValue<ShapedType>>(partitionMap.
lookup(srcShardOp));
657 targetPartitionValue =
reshard(builder, srcShardOp, shardOp,
658 srcPartitionValue, symbolTableCollection);
661 assert(!partitionMap.
contains(shardOp.getResult()));
662 partitionMap.
map(shardOp.getResult(), targetPartitionValue);
670 if (isa<ShardingOp>(op)) {
673 if (
auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
674 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
676 return op.
emitError(
"expected a shard op as source of get_sharding");
678 auto newSharding = builder.
clone(*shardOp.getSharding().getDefiningOp());
679 partitionMap.
map(op.
getResult(0), newSharding->getResult(0));
683 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
690 llvm::transform(op.
getOperands(), std::back_inserter(partitionedOperands),
691 [&partitionMap](
Value operand) {
692 assert(partitionMap.contains(operand));
693 return partitionMap.lookup(operand);
697 symbolTableCollection, builder);
706 llvm::transform(block.
getArguments(), std::back_inserter(argLocations),
711 for (
auto [unshardedBlockArg, partitionedBlockArg] :
713 partitionMap.
map(unshardedBlockArg, partitionedBlockArg);
736 for (
Block &
b : op.getBlocks()) {
737 if (llvm::any_of(
b.getOperations(),
738 [](
Operation &op) { return isa<ShardOp>(op); })) {
739 originalBlocks.push_back(&
b);
743 for (
Block *block : originalBlocks) {
744 if (failed(
partitionBlock(*block, partitionMap, symbolTableCollection,
750 for (
Block *block : originalBlocks) {
757 for (
Block &block : op.getFunctionBody()) {
763 returnOp = &block.back();
768 op.setType(FunctionType::get(
769 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
778struct Partition :
public impl::PartitionBase<Partition> {
779 void runOnOperation()
override {
782 if (failed(partitionFuncOp(getOperation(), partitionMap,
783 symbolTableCollection))) {
784 return signalPassFailure();
788 void getDependentDialects(DialectRegistry &registry)
const override {
790 registry.insert<shard::ShardDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
static std::optional< std::tuple< int64_t, GridAxis > > detectUnsplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static TypedValue< ShapedType > reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
bool isFullReplication(Sharding sharding)
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis)
static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding(Sharding sourceSharding, Sharding targetSharding)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, Sharding sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".