29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
41template <
typename SourceAxes,
typename TargetAxes>
43 const TargetAxes &targetAxes) {
44 return llvm::all_of(targetAxes, [&sourceAxes](
auto &targetAxis) {
45 return sourceAxes.contains(targetAxis);
58 virtual std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
87 while (
static_cast<int64_t>(tgtShardingSplitAxes.size()) <=
92 llvm::to_vector(tgtShardingSplitAxes[splitTensorDim].asArrayRef());
93 tgtSplitAxes.push_back(splitGridAxis);
101 static std::tuple<TypedValue<ShapedType>,
Sharding>
106 AllSliceOp::create(builder, srcShard, grid,
110 tgtSharding(builder.
getContext(), std::move(srcSharding),
111 splitTensorDim, splitGridAxis);
112 return {tgtShard, resultSharding};
120 static std::optional<GridAxis> detect(
const Sharding &srcSharding,
123 if (
static_cast<size_t>(tensorDim) >= tgtSharding.getSplitAxes().size())
125 auto tgtAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
126 if (srcSharding.
getSplitAxes().size() >
static_cast<size_t>(tensorDim)) {
127 auto srcAxes = srcSharding.
getSplitAxes()[tensorDim].asArrayRef();
128 if (srcAxes.size() + 1 != tgtAxes.size())
130 if (!llvm::equal(srcAxes,
131 llvm::make_range(tgtAxes.begin(), tgtAxes.end() - 1)))
134 if (tgtAxes.size() != 1)
137 return tgtAxes.back();
141 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
144 ShapedType srcUnshardedType,
148 if (
auto gridAxis = detect(srcSharding, tgtSharding, tensorDim))
149 return apply(builder, srcSharding, srcShard, grid, tensorDim,
161 static std::optional<SmallVector<GridAxis>>
164 if (
static_cast<size_t>(tensorDim) >= srcSharding.
getSplitAxes().size())
167 auto srcSplitAxes = srcSharding.
getSplitAxes()[tensorDim].asArrayRef();
168 if (tgtSharding.getSplitAxes().size() >
static_cast<size_t>(tensorDim)) {
169 auto tgtSplitAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
172 if (srcSplitAxes.size() <= tgtSplitAxes.size())
176 if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(),
177 srcSplitAxes.begin()))
179 dimOff = tgtSplitAxes.size();
183 if (srcSplitAxes.size() == 0)
197 int64_t splitTensorDim,
size_t numUnsplitAxes) {
200 assert(
static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
202 assert(srcSplitAxes.size() >= numUnsplitAxes);
203 size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
205 srcSplitAxes.begin() + numSplitAxes);
212 static ShapedType allGatherResultType(ShapedType srcType,
217 for (
GridAxis gridAxis : unsplitAxes)
218 tgtShape[splitTensorDim] =
220 return srcType.cloneWith(tgtShape, srcType.getElementType());
225 static std::tuple<TypedValue<ShapedType>,
Sharding>
232 Sharding resultSharding = tgtSharding(ctx, std::move(srcSharding),
233 splitTensorDim, unsplitAxes.size());
234 ShapedType agResultType = allGatherResultType(
235 srcShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
236 Value allGatherResult = AllGatherOp::create(
238 RankedTensorType::get(agResultType.getShape(),
239 agResultType.getElementType()),
240 grid.getSymName(), unsplitAxes, srcShard, APInt(64, splitTensorDim));
244 tensor::CastOp::create(builder, tgtType, allGatherResult).getResult();
245 return {tgtShard, resultSharding};
249 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
252 ShapedType srcUnshardedType,
256 if (
auto gridAxes = detect(srcSharding, tgtSharding, tensorDim))
257 return apply(builder, srcSharding, srcUnshardedType, srcShard, grid,
258 tensorDim, gridAxes.value());
269 tgtShape[srcTensorDim] =
gatherDimension(tgtShape[srcTensorDim], splitCount);
270 tgtShape[tgtTensorDim] =
shardDimension(tgtShape[tgtTensorDim], splitCount);
271 return srcShape.cloneWith(tgtShape, srcShape.getElementType());
286 static std::optional<std::tuple<int64_t, GridAxis>>
289 if (
static_cast<size_t>(srcTensorDim) >= srcSharding.
getSplitAxes().size())
291 auto srcAxes = srcSharding.
getSplitAxes()[srcTensorDim].asArrayRef();
297 if (
static_cast<size_t>(srcTensorDim) >= tgtSharding.getSplitAxes().size())
299 auto tgtSrcAxes = tgtSharding.getSplitAxes()[srcTensorDim].asArrayRef();
300 if (tgtSrcAxes.size() + 1 != srcAxes.size())
303 if (!llvm::equal(tgtSrcAxes,
304 llvm::make_range(srcAxes.begin(), srcAxes.end() - 1)))
307 GridAxis movedAxis = srcAxes.back();
311 for (
size_t tgtTensorDim = 0;
312 tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
313 if (
static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
315 auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
317 if (tgtAxes.empty() || tgtAxes.front() != movedAxis)
322 static_cast<size_t>(tgtTensorDim) < srcSharding.
getSplitAxes().size()
325 if (!llvm::equal(srcTgtAxes,
326 llvm::make_range(tgtAxes.begin() + 1, tgtAxes.end())))
328 return std::make_tuple(
static_cast<int64_t>(tgtTensorDim), movedAxis);
340 while (
static_cast<int64_t>(splitAxes.size()) <= tgtTensorDim)
344 auto srcSplitAxes = llvm::to_vector(splitAxes[srcTensorDim].asArrayRef());
345 assert(!srcSplitAxes.empty() && srcSplitAxes.back() == movedAxis);
346 srcSplitAxes.pop_back();
350 auto tgtSplitAxes = llvm::to_vector(splitAxes[tgtTensorDim].asArrayRef());
351 tgtSplitAxes.insert(tgtSplitAxes.begin(), movedAxis);
357 static std::tuple<TypedValue<ShapedType>,
Sharding>
365 tgtSharding(ctx, srcSharding, srcTensorDim, tgtTensorDim, movedAxis);
366 ShapedType a2aResultShape =
368 srcTensorDim, tgtTensorDim);
369 Value allToAllResult = AllToAllOp::create(
371 RankedTensorType::get(a2aResultShape.getShape(),
372 a2aResultShape.getElementType()),
374 APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
375 ShapedType tgtShape =
378 tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
379 return {tgtShard, resultSharding};
383 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
386 ShapedType srcUnshardedType,
390 if (
auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
391 auto [tgtTensorDim, movedAxis] = detectRes.value();
392 return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
393 tensorDim, tgtTensorDim, movedAxis);
404 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
407 ShapedType srcUnshardedType,
422 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
423 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
424 ShapedType::isStaticShape(tgtHaloSizes) &&
425 srcShard.getType().hasStaticShape()) &&
426 "dynamic shapes/halos are not supported yet for shard-partition");
427 auto rank = srcShard.getType().getRank();
430 strides(rank, 1), outShape(srcShard.getType().getShape()),
431 coreShape(srcShard.getType().getShape());
435 for (
auto i = 0u; i < rank; ++i) {
436 if (i < splitAxes.size() && !splitAxes[i].empty()) {
437 if (!srcHaloSizes.empty()) {
438 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
439 srcCoreOffs[i] = srcHaloSizes[i * 2];
441 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
443 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
449 auto initVal = tensor::EmptyOp::create(builder, srcShard.
getLoc(), outShape,
450 srcShard.getType().getElementType());
451 auto core = tensor::ExtractSliceOp::create(
452 builder, srcShard.
getLoc(),
453 RankedTensorType::get(coreShape, srcShard.getType().getElementType()),
454 srcShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
455 auto initOprnd = tensor::InsertSliceOp::create(
456 builder, srcShard.
getLoc(), core, initVal, noVals, noVals, noVals,
457 tgtCoreOffs, coreShape, strides);
460 auto updateHaloResult =
461 UpdateHaloOp::create(builder, srcShard.
getLoc(),
462 RankedTensorType::get(
463 outShape, srcShard.getType().getElementType()),
464 initOprnd, grid.getSymName(),
478 GridOp grid,
const Sharding &srcSharding,
483 if (srcSharding == tgtSharding ||
488 assert(shardedSrc.getType() ==
490 [[maybe_unused]] ShapedType tgtShardType =
492 assert(shardedSrc.getType().getRank() == tgtShardType.getRank());
493 assert(unshardedSrc.getType().getRank() == tgtShardType.getRank());
501 &updateHaloPattern, &moveLastSplitAxisPattern, &splitLastAxisPattern,
502 &unsplitLastAxesPattern};
504 Sharding currentSharding = srcSharding;
506 dim < tgtShardType.getRank() && currentSharding != tgtSharding; ++dim) {
507 for (
auto &pattern : patterns) {
508 if (
auto tryRes = pattern->tryApply(builder, grid, dim, currentSharding,
509 tgtSharding, unshardedSrc.getType(),
511 std::tie(currentShard, currentSharding) = tryRes.value();
517 if (currentSharding != tgtSharding ||
518 currentShard.getType() != tgtShardType) {
520 <<
"Failed to reshard; probably hitting an unknown resharding pattern:"
521 <<
" got " << currentSharding <<
" expected " << tgtSharding
522 <<
" got type " << currentShard.getType() <<
" expected "
530 ShardOp srcShardOp, ShardOp tgtShardOp,
532 assert(srcShardOp.getResult() == tgtShardOp.getSrc());
533 auto srcSharding = srcShardOp.getSharding();
534 auto tgtSharding = tgtShardOp.getSharding();
536 return reshard(implicitLocOpBuilder, grid, srcSharding, tgtSharding,
537 srcShardOp.getSrc(), shardedSrc);
544 GridOp srcGrid =
getGrid(srcShardOp, symbolTableCollection);
545 assert(srcGrid && srcGrid ==
getGrid(tgtShardOp, symbolTableCollection));
546 return reshard(builder, srcGrid, srcShardOp, tgtShardOp, shardedSrc);
550 registry.
insert<shard::ShardDialect, tensor::TensorDialect>();
553#define GEN_PASS_DEF_PARTITION
554#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
568 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
569 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
570 rankedTensorArg.use_empty()) {
571 return arg.getType();
574 assert(rankedTensorArg.hasOneUse());
576 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
578 GridOp grid =
getGrid(shardOp, symbolTableCollection);
580 shardOp.getSharding()));
591 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
592 if (!shardingInterface) {
596 resultShardings, partitionMap,
597 symbolTableCollection, builder);
599 if (failed(shardingInterface.partition(
600 partitionedOperands, operandShardings, resultShardings,
601 partitionMap, symbolTableCollection, builder))) {
607 return partitionMap.contains(result);
616 std::vector<Sharding> res;
618 llvm::transform(op.
getOperands(), std::back_inserter(res), [](
Value operand) {
619 TypedValue<RankedTensorType> rankedTensor =
620 dyn_cast<TypedValue<RankedTensorType>>(operand);
621 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
627 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
628 return Sharding(shardOp.getSharding());
636 std::vector<Sharding> res;
640 if (!result.hasOneUse() || result.use_empty()) {
649 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
651 return Sharding(shardOp.getSharding());
653 if (rankedTensor.getType().getRank() == 0) {
658 if (
auto sharding = operand.getDefiningOp<ShardingOp>()) {
659 return Sharding(sharding.getGridAttr());
672 Value tgtPartitionValue;
676 ShardOp srcShardOp = shardOp.getSrc().
getDefiningOp<ShardOp>();
678 tgtPartitionValue = partitionMap.
lookup(shardOp.getSrc());
682 cast<TypedValue<ShapedType>>(partitionMap.
lookup(srcShardOp));
683 tgtPartitionValue =
reshard(builder, srcShardOp, shardOp, shardedSrc,
684 symbolTableCollection);
685 if (!tgtPartitionValue) {
686 return shardOp.emitError()
687 <<
"Failed to reshard from " << srcShardOp.getSharding() <<
" to "
688 << shardOp.getSharding();
692 assert(!partitionMap.
contains(shardOp.getResult()));
693 partitionMap.
map(shardOp.getResult(), tgtPartitionValue);
703 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
704 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
705 rankedTensorArg.use_empty())
708 if (!rankedTensorArg.hasOneUse())
710 <<
"Cannot partition: expected a single use for block argument "
711 << arg.getArgNumber() <<
" in block "
715 auto shardOp = dyn_cast<ShardOp>(useOp);
718 <<
"Cannot partition: expected a shard.shard op for block "
719 <<
"argument " << arg.getArgNumber() <<
" in block "
740 auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
741 if (!rankedTT || rankedTT.getRank() == 0)
744 auto shard = operand.get().getDefiningOp<ShardOp>();
746 return op->
emitError() <<
"Cannot partition: tensor operand "
747 << operand.getOperandNumber()
748 <<
" must be defined by a shard.shard operation.";
749 if (!
shard.getAnnotateForUsers())
751 <<
"Cannot partition: shard.shard for operand "
752 << operand.getOperandNumber() <<
" must set 'annotate_for_users'.";
757 <<
"Cannot partition: result " <<
result.getResultNumber()
758 <<
" must have exactly one use.";
759 auto shard = dyn_cast<ShardOp>(*
result.user_begin());
762 <<
"Cannot partition: user of result " <<
result.getResultNumber()
763 <<
" must be shard.shard operation.";
764 if (
shard.getAnnotateForUsers())
765 return op->
emitError() <<
"Cannot partition: shard.shard for result "
766 <<
result.getResultNumber()
767 <<
" must not set 'annotate_for_users'.";
776 if (isa<ShardingOp>(op)) {
780 if (
auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
781 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
783 return op.
emitError(
"expected a shard op as source of get_sharding");
785 auto newSharding = builder.
clone(*shardOp.getSharding().getDefiningOp());
786 partitionMap.
map(op.
getResult(0), newSharding->getResult(0));
790 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
801 llvm::transform(op.
getOperands(), std::back_inserter(partitionedOperands),
802 [&partitionMap](
Value operand) {
803 assert(partitionMap.contains(operand));
804 return partitionMap.lookup(operand);
808 symbolTableCollection, builder);
820 llvm::transform(block.
getArguments(), std::back_inserter(argLocations),
825 for (
auto [unshardedBlockArg, partitionedBlockArg] :
827 partitionMap.
map(unshardedBlockArg, partitionedBlockArg);
850 for (
Block &
b : op.getBlocks()) {
851 if (llvm::any_of(
b.getOperations(),
852 [](
Operation &op) { return isa<ShardOp>(op); })) {
853 originalBlocks.push_back(&
b);
857 for (
Block *block : originalBlocks) {
858 if (failed(
partitionBlock(*block, partitionMap, symbolTableCollection,
864 for (
Block *block : originalBlocks) {
871 for (
Block &block : op.getFunctionBody()) {
877 returnOp = &block.back();
882 op.setType(FunctionType::get(
883 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
892struct Partition :
public impl::PartitionBase<Partition> {
893 void runOnOperation()
override {
896 if (failed(partitionFuncOp(getOperation(), partitionMap,
897 symbolTableCollection))) {
898 return signalPassFailure();
902 void getDependentDialects(DialectRegistry ®istry)
const override {
904 registry.insert<shard::ShardDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location each time.
Location getLoc() const
Accessors for the implied location.
mlir::InFlightDiagnostic emitError(const llvm::Twine &message=llvm::Twine())
This builder can also be used to emit diagnostics to the current location.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Location getLoc()
Return a location for this region.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Move the last split axis of one tensor dimension to the front of another tensor dimension's split axes.
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Base class for resharding patterns.
virtual std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard)=0
Try to apply this resharding pattern.
static bool hasStaticOffsetsOrHalos(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets or non-empty static halo sizes.
static bool hasStaticOffsets(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets.
virtual ~ReshardingPattern()=default
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]].
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> [].
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Update halo sizes: handles cases where only the halo sizes differ between source and target sharding.
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static ShapedType allToAllResultShape(ShapedType srcShape, int64_t splitCount, int64_t srcTensorDim, int64_t tgtTensorDim)
static LogicalResult checkFullyAnnotated(Block &block)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
bool isFullReplication(Sharding sharding)
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings(Operation &op)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
void reshardingRegisterDependentDialects(DialectRegistry ®istry)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".