29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
39template <
typename SourceAxes,
typename TargetAxes>
41 const TargetAxes &targetAxes) {
42 return llvm::all_of(targetAxes, [&sourceAxes](
auto &targetAxis) {
43 return sourceAxes.contains(targetAxis);
53 while (
static_cast<int64_t>(targetShardingSplitAxes.size()) <=
57 auto targetSplitAxes =
58 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
59 targetSplitAxes.push_back(splitGridAxis);
60 targetShardingSplitAxes[splitTensorAxis] =
68static std::tuple<TypedValue<ShapedType>, Sharding>
74 AllSliceOp::create(builder, sourceShard, grid,
78 builder.
getContext(), std::move(sourceSharding), splitTensorAxis,
80 return {targetShard, targetSharding};
88static std::optional<std::tuple<int64_t, GridAxis>>
91 for (
size_t tensorAxis = 0; tensorAxis < targetSharding.
getSplitAxes().size();
94 if (sourceSharding.
getSplitAxes()[tensorAxis].size() + 1 !=
104 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef().end() -
109 if (targetSharding.
getSplitAxes()[tensorAxis].size() != 1) {
113 return std::make_tuple(
115 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef().back());
120static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
126 sourceSharding, std::move(targetSharding))) {
127 auto [tensorAxis, gridAxis] = detectRes.value();
129 tensorAxis, gridAxis);
138static std::optional<std::tuple<int64_t, GridAxis>>
141 for (
size_t tensorAxis = 0; tensorAxis < sourceSharding.
getSplitAxes().size();
143 if (targetSharding.
getSplitAxes().size() > tensorAxis) {
152 sourceSharding.
getSplitAxes()[tensorAxis].asArrayRef().end() -
154 targetSharding.
getSplitAxes()[tensorAxis].asArrayRef()))
157 if (sourceSharding.
getSplitAxes()[tensorAxis].size() != 1)
160 return std::make_tuple(
162 sourceSharding.
getSplitAxes()[tensorAxis].asArrayRef().back());
172 assert(
static_cast<int64_t>(targetShardingSplitAxes.size()) >
174 auto targetSplitAxes =
175 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
177 targetSplitAxes.pop_back();
178 targetShardingSplitAxes[splitTensorAxis] =
184 ShapedType sourceShape,
int64_t splitCount,
int64_t splitTensorAxis) {
186 targetShape[splitTensorAxis] =
188 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
199 ctx, std::move(sourceSharding), splitTensorAxis);
201 sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis);
202 Value allGatherResult = AllGatherOp::create(
204 RankedTensorType::get(allGatherResultShape.getShape(),
205 allGatherResultShape.getElementType()),
207 APInt(64, splitTensorAxis));
208 ShapedType targetShape =
211 tensor::CastOp::create(builder, targetShape, allGatherResult).getResult();
212 return {targetShard, targetSharding};
215static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
219 ShapedType sourceUnshardedShape,
222 sourceSharding, std::move(targetSharding))) {
223 auto [tensorAxis, gridAxis] = detectRes.value();
225 sourceUnshardedShape, sourceShard, grid,
226 tensorAxis, gridAxis);
237static std::optional<std::tuple<int64_t, int64_t, GridAxis>>
240 for (
size_t sourceTensorAxis = 0;
241 sourceTensorAxis < sourceSharding.
getSplitAxes().size();
242 ++sourceTensorAxis) {
243 for (
size_t targetTensorAxis = 0;
244 targetTensorAxis < targetSharding.
getSplitAxes().size();
245 ++targetTensorAxis) {
246 if (sourceTensorAxis == targetTensorAxis)
248 if (sourceSharding.
getSplitAxes()[sourceTensorAxis].empty() ||
249 targetSharding.
getSplitAxes()[targetTensorAxis].empty() ||
250 sourceSharding.
getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
256 llvm::make_range(sourceSharding.
getSplitAxes()[sourceTensorAxis]
263 llvm::make_range(targetSharding.
getSplitAxes()[targetTensorAxis]
271 return std::make_tuple(
272 sourceTensorAxis, targetTensorAxis,
273 sourceSharding.
getSplitAxes()[sourceTensorAxis].asArrayRef().back());
285 while (
static_cast<int64_t>(targetShardingSplitAxes.size()) <=
290 auto sourceSplitAxes =
291 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
292 assert(!sourceSplitAxes.empty());
293 auto gridAxis = sourceSplitAxes.back();
294 sourceSplitAxes.pop_back();
295 targetShardingSplitAxes[sourceTensorAxis] =
298 auto targetSplitAxes =
299 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
300 targetSplitAxes.push_back(gridAxis);
301 targetShardingSplitAxes[targetTensorAxis] =
312 targetShape[sourceTensorAxis] =
314 targetShape[targetTensorAxis] =
316 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
319static std::tuple<TypedValue<ShapedType>, Sharding>
322 ShapedType sourceUnshardedShape,
330 ctx, std::move(sourceSharding), sourceTensorAxis, targetTensorAxis);
332 sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis,
334 Value allToAllResult = AllToAllOp::create(
336 RankedTensorType::get(allToAllResultShape.getShape(),
337 allToAllResultShape.getElementType()),
339 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
340 ShapedType targetShape =
343 tensor::CastOp::create(builder, targetShape, allToAllResult).getResult();
344 return {targetShard, targetSharding};
347static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
351 ShapedType sourceUnshardedShape,
354 sourceSharding, std::move(targetSharding))) {
355 auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value();
357 builder, grid, sourceSharding, sourceUnshardedShape, sourceShard,
358 sourceTensorAxis, targetTensorAxis, gridAxis);
368static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>>
372 ShapedType sourceUnshardedShape,
385 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
386 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
387 ShapedType::isStaticShape(tgtHaloSizes) &&
388 sourceShard.getType().hasStaticShape()) &&
389 "dynamic shapes/halos are not supported yet for shard-partition");
390 auto rank = sourceShard.getType().getRank();
393 strides(rank, 1), outShape(sourceShard.getType().getShape()),
394 coreShape(sourceShard.getType().getShape());
398 for (
auto i = 0u; i < rank; ++i) {
399 if (i < splitAxes.size() && !splitAxes[i].empty()) {
400 if (!srcHaloSizes.empty()) {
401 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
402 srcCoreOffs[i] = srcHaloSizes[i * 2];
404 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
406 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
413 tensor::EmptyOp::create(builder, sourceShard.
getLoc(), outShape,
414 sourceShard.getType().getElementType());
415 auto core = tensor::ExtractSliceOp::create(
416 builder, sourceShard.
getLoc(),
417 RankedTensorType::get(coreShape, sourceShard.getType().getElementType()),
418 sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
419 auto initOprnd = tensor::InsertSliceOp::create(
420 builder, sourceShard.
getLoc(), core, initVal, noVals, noVals, noVals,
421 tgtCoreOffs, coreShape, strides);
424 auto updateHaloResult =
425 UpdateHaloOp::create(
426 builder, sourceShard.
getLoc(),
427 RankedTensorType::get(outShape,
428 sourceShard.getType().getElementType()),
429 initOprnd, grid.getSymName(),
447 assert(sourceShard.getType() ==
448 shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding));
449 [[maybe_unused]] ShapedType targetShardType =
451 assert(sourceShard.getType().getRank() == targetShardType.getRank());
452 assert(grid.getRank() == 1 &&
"Only 1D grides are currently supported.");
454 if (sourceSharding == targetSharding) {
465 builder, grid, sourceSharding, targetSharding,
466 sourceUnshardedValue.getType(), sourceShard)) {
467 std::tie(targetShard, actualTargetSharding) = tryRes.value();
468 }
else if (
auto tryRes =
470 targetSharding, sourceShard)) {
471 std::tie(targetShard, actualTargetSharding) = tryRes.value();
473 builder, grid, sourceSharding, targetSharding,
474 sourceUnshardedValue.getType(), sourceShard)) {
475 std::tie(targetShard, actualTargetSharding) = tryRes.value();
478 assert(targetShard &&
"Did not find any pattern to apply.");
479 assert(actualTargetSharding == targetSharding);
480 assert(targetShard.getType() == targetShardType);
498 builder, grid, sourceSharding, targetSharding,
499 sourceUnshardedValue.getType(), sourceShard)) {
500 return std::get<0>(tryRes.value());
507 sourceUnshardedValue, sourceShard);
513 assert(source.getResult() ==
target.getSrc());
514 auto sourceSharding = source.getSharding();
515 auto targetSharding =
target.getSharding();
517 return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding,
518 source.getSrc(), sourceShardValue);
525 GridOp srcGrid =
getGrid(source, symbolTableCollection);
526 assert(srcGrid && srcGrid ==
getGrid(
target, symbolTableCollection));
527 return reshard(builder, srcGrid, source,
target, sourceShardValue);
531 registry.
insert<shard::ShardDialect, tensor::TensorDialect>();
534#define GEN_PASS_DEF_PARTITION
535#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
549 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
550 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
551 return arg.getType();
554 assert(rankedTensorArg.hasOneUse());
556 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
558 GridOp grid =
getGrid(shardOp, symbolTableCollection);
560 shardOp.getSharding()));
571 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
572 if (!shardingInterface) {
576 resultShardings, partitionMap,
577 symbolTableCollection, builder);
579 if (failed(shardingInterface.partition(
580 partitionedOperands, operandShardings, resultShardings,
581 partitionMap, symbolTableCollection, builder))) {
587 return partitionMap.contains(result);
596 std::vector<Sharding> res;
598 llvm::transform(op.
getOperands(), std::back_inserter(res), [](
Value operand) {
599 TypedValue<RankedTensorType> rankedTensor =
600 dyn_cast<TypedValue<RankedTensorType>>(operand);
601 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
607 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
608 return Sharding(shardOp.getSharding());
616 std::vector<Sharding> res;
620 if (!result.hasOneUse() || result.use_empty()) {
629 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
631 return Sharding(shardOp.getSharding());
633 if (rankedTensor.getType().getRank() == 0) {
638 if (
auto sharding = operand.getDefiningOp<ShardingOp>()) {
639 return Sharding(sharding.getGridAttr());
652 Value targetPartitionValue;
656 ShardOp srcShardOp = shardOp.getSrc().
getDefiningOp<ShardOp>();
658 targetPartitionValue = partitionMap.
lookup(shardOp.getSrc());
662 cast<TypedValue<ShapedType>>(partitionMap.
lookup(srcShardOp));
663 targetPartitionValue =
reshard(builder, srcShardOp, shardOp,
664 srcPartitionValue, symbolTableCollection);
667 assert(!partitionMap.
contains(shardOp.getResult()));
668 partitionMap.
map(shardOp.getResult(), targetPartitionValue);
678 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
679 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0)
682 if (rankedTensorArg.getNumUses() > 1)
684 <<
"Cannot partition: expected a single use for block argument "
685 << arg.getArgNumber() <<
" in block "
688 auto shardOp = dyn_cast<ShardOp>(useOp);
691 <<
"Cannot partition: expected a shard.shard op for block "
692 <<
"argument " << arg.getArgNumber() <<
" in block "
713 auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
714 if (!rankedTT || rankedTT.getRank() == 0)
717 auto shard = operand.get().getDefiningOp<ShardOp>();
719 return op->
emitError() <<
"Cannot partition: tensor operand "
720 << operand.getOperandNumber()
721 <<
" must be defined by a shard.shard operation.";
722 if (!
shard.getAnnotateForUsers())
724 <<
"Cannot partition: shard.shard for operand "
725 << operand.getOperandNumber() <<
" must set 'annotate_for_users'.";
730 <<
"Cannot partition: result " <<
result.getResultNumber()
731 <<
" must have exactly one use.";
732 auto shard = dyn_cast<ShardOp>(*
result.user_begin());
735 <<
"Cannot partition: user of result " <<
result.getResultNumber()
736 <<
" must be shard.shard operation.";
737 if (
shard.getAnnotateForUsers())
738 return op->
emitError() <<
"Cannot partition: shard.shard for result "
739 <<
result.getResultNumber()
740 <<
" must not set 'annotate_for_users'.";
749 if (isa<ShardingOp>(op)) {
753 if (
auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
754 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
756 return op.
emitError(
"expected a shard op as source of get_sharding");
758 auto newSharding = builder.
clone(*shardOp.getSharding().getDefiningOp());
759 partitionMap.
map(op.
getResult(0), newSharding->getResult(0));
763 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
774 llvm::transform(op.
getOperands(), std::back_inserter(partitionedOperands),
775 [&partitionMap](
Value operand) {
776 assert(partitionMap.contains(operand));
777 return partitionMap.lookup(operand);
781 symbolTableCollection, builder);
793 llvm::transform(block.
getArguments(), std::back_inserter(argLocations),
798 for (
auto [unshardedBlockArg, partitionedBlockArg] :
800 partitionMap.
map(unshardedBlockArg, partitionedBlockArg);
823 for (
Block &
b : op.getBlocks()) {
824 if (llvm::any_of(
b.getOperations(),
825 [](
Operation &op) { return isa<ShardOp>(op); })) {
826 originalBlocks.push_back(&
b);
830 for (
Block *block : originalBlocks) {
831 if (failed(
partitionBlock(*block, partitionMap, symbolTableCollection,
837 for (
Block *block : originalBlocks) {
844 for (
Block &block : op.getFunctionBody()) {
850 returnOp = &block.back();
855 op.setType(FunctionType::get(
856 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
865struct Partition :
public impl::PartitionBase<Partition> {
866 void runOnOperation()
override {
869 if (failed(partitionFuncOp(getOperation(), partitionMap,
870 symbolTableCollection))) {
871 return signalPassFailure();
875 void getDependentDialects(DialectRegistry ®istry)
const override {
877 registry.insert<shard::ShardDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location. It is otherwise the same as OpBuilder.
Location getLoc() const
Accessors for the implied location.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Location getLoc()
Return a location for this region.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, const Sharding &sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::tuple< TypedValue< ShapedType >, Sharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
static TypedValue< ShapedType > reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, const Sharding &targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, Sharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, GridAxis gridAxis)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
static std::optional< std::tuple< int64_t, GridAxis > > detectUnsplitLastAxisInResharding(const Sharding &sourceSharding, const Sharding &targetSharding)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::tuple< TypedValue< ShapedType >, Sharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static LogicalResult checkFullyAnnotated(Block &block)
static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorAxis, GridAxis splitGridAxis)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
bool isFullReplication(Sharding sharding)
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::optional< std::tuple< int64_t, int64_t, GridAxis > > detectMoveLastSplitAxisInResharding(const Sharding &sourceSharding, const Sharding &targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, const Sharding &targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::vector< Sharding > getOperandShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, TypedValue< ShapedType > sourceShard)
static std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, const Sharding &sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, const Sharding &sourceSharding, int64_t splitTensorAxis)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
void reshardingRegisterDependentDialects(DialectRegistry ®istry)
static std::optional< std::tuple< int64_t, GridAxis > > detectSplitLastAxisInResharding(const Sharding &sourceSharding, const Sharding &targetSharding)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".