29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/Support/Casting.h"
41template <
typename SourceAxes,
typename TargetAxes>
43 const TargetAxes &targetAxes) {
44 return llvm::all_of(targetAxes, [&sourceAxes](
auto &targetAxis) {
45 return sourceAxes.contains(targetAxis);
58 virtual std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
87 while (
static_cast<int64_t>(tgtShardingSplitAxes.size()) <=
92 llvm::to_vector(tgtShardingSplitAxes[splitTensorDim].asArrayRef());
93 tgtSplitAxes.push_back(splitGridAxis);
101 static std::tuple<TypedValue<ShapedType>,
Sharding>
106 AllSliceOp::create(builder, srcShard, grid,
110 tgtSharding(builder.
getContext(), std::move(srcSharding),
111 splitTensorDim, splitGridAxis);
112 return {tgtShard, resultSharding};
120 static std::optional<GridAxis> detect(
const Sharding &srcSharding,
123 if (
static_cast<size_t>(tensorDim) >= tgtSharding.getSplitAxes().size())
125 auto tgtAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
126 if (srcSharding.
getSplitAxes().size() >
static_cast<size_t>(tensorDim)) {
127 auto srcAxes = srcSharding.
getSplitAxes()[tensorDim].asArrayRef();
128 if (srcAxes.size() + 1 != tgtAxes.size())
130 if (!llvm::equal(srcAxes,
131 llvm::make_range(tgtAxes.begin(), tgtAxes.end() - 1)))
134 if (tgtAxes.size() != 1)
137 return tgtAxes.back();
141 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
144 ShapedType srcUnshardedType,
148 if (
auto gridAxis = detect(srcSharding, tgtSharding, tensorDim))
149 return apply(builder, srcSharding, srcShard, grid, tensorDim,
161 static std::optional<SmallVector<GridAxis>>
164 if (
static_cast<size_t>(tensorDim) >= srcSharding.
getSplitAxes().size())
167 auto srcSplitAxes = srcSharding.
getSplitAxes()[tensorDim].asArrayRef();
168 if (tgtSharding.getSplitAxes().size() >
static_cast<size_t>(tensorDim)) {
169 auto tgtSplitAxes = tgtSharding.getSplitAxes()[tensorDim].asArrayRef();
172 if (srcSplitAxes.size() <= tgtSplitAxes.size())
176 if (!std::equal(tgtSplitAxes.begin(), tgtSplitAxes.end(),
177 srcSplitAxes.begin()))
179 dimOff = tgtSplitAxes.size();
183 if (srcSplitAxes.size() == 0)
197 int64_t splitTensorDim,
size_t numUnsplitAxes) {
200 assert(
static_cast<int64_t>(resSplitAxes.size()) > splitTensorDim);
202 assert(srcSplitAxes.size() >= numUnsplitAxes);
203 size_t numSplitAxes = srcSplitAxes.size() - numUnsplitAxes;
205 srcSplitAxes.begin() + numSplitAxes);
212 static ShapedType allGatherResultType(ShapedType srcType,
217 for (
GridAxis gridAxis : unsplitAxes)
218 tgtShape[splitTensorDim] =
220 return srcType.cloneWith(tgtShape, srcType.getElementType());
225 static std::tuple<TypedValue<ShapedType>,
Sharding>
232 Sharding resultSharding = tgtSharding(ctx, std::move(srcSharding),
233 splitTensorDim, unsplitAxes.size());
234 ShapedType agResultType = allGatherResultType(
235 srcShard.getType(), splitTensorDim, grid.getShape(), unsplitAxes);
236 Value allGatherResult = AllGatherOp::create(
238 RankedTensorType::get(agResultType.getShape(),
239 agResultType.getElementType()),
240 grid.getSymName(), unsplitAxes, srcShard, APInt(64, splitTensorDim));
244 tensor::CastOp::create(builder, tgtType, allGatherResult).getResult();
245 return {tgtShard, resultSharding};
249 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
252 ShapedType srcUnshardedType,
256 if (
auto gridAxes = detect(srcSharding, tgtSharding, tensorDim))
257 return apply(builder, srcSharding, srcUnshardedType, srcShard, grid,
258 tensorDim, gridAxes.value());
269 static std::optional<std::tuple<int64_t, GridAxis>>
272 if (
static_cast<size_t>(srcTensorDim) >= srcSharding.
getSplitAxes().size())
274 auto srcAxes = srcSharding.
getSplitAxes()[srcTensorDim].asArrayRef();
275 if (srcAxes.size() != 1)
277 for (
size_t tgtTensorDim = 0;
278 tgtTensorDim < tgtSharding.getSplitAxes().size(); ++tgtTensorDim) {
279 if (
static_cast<int64_t>(tgtTensorDim) == srcTensorDim)
281 auto tgtAxes = tgtSharding.getSplitAxes()[tgtTensorDim].asArrayRef();
282 if (tgtAxes.size() != 1 || srcAxes.front() != tgtAxes.front())
284 return std::make_tuple(
static_cast<int64_t>(tgtTensorDim),
294 while (
static_cast<int64_t>(tgtShardingSplitAxes.size()) <= tgtTensorDim) {
299 llvm::to_vector(tgtShardingSplitAxes[srcTensorDim].asArrayRef());
300 assert(srcSplitAxes.size() == 1);
301 auto gridAxis = srcSplitAxes.back();
302 srcSplitAxes.pop_back();
306 llvm::to_vector(tgtShardingSplitAxes[tgtTensorDim].asArrayRef());
307 tgtSplitAxes.push_back(gridAxis);
313 static ShapedType allToAllResultShape(ShapedType srcShape,
int64_t splitCount,
317 tgtShape[srcTensorDim] =
319 tgtShape[tgtTensorDim] =
shardDimension(tgtShape[tgtTensorDim], splitCount);
320 return srcShape.cloneWith(tgtShape, srcShape.getElementType());
323 static std::tuple<TypedValue<ShapedType>,
Sharding>
331 tgtSharding(ctx, std::move(srcSharding), srcTensorDim, tgtTensorDim);
332 ShapedType a2aResultShape =
333 allToAllResultShape(srcShard.getType(), grid.getShape()[gridAxis],
334 srcTensorDim, tgtTensorDim);
335 Value allToAllResult = AllToAllOp::create(
337 RankedTensorType::get(a2aResultShape.getShape(),
338 a2aResultShape.getElementType()),
340 APInt(64, tgtTensorDim), APInt(64, srcTensorDim));
341 ShapedType tgtShape =
344 tensor::CastOp::create(builder, tgtShape, allToAllResult).getResult();
345 return {tgtShard, resultSharding};
349 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
352 ShapedType srcUnshardedType,
356 if (
auto detectRes = detect(srcSharding, tgtSharding, tensorDim)) {
357 auto [tgtTensorDim, gridAxis] = detectRes.value();
358 return apply(builder, grid, srcSharding, srcUnshardedType, srcShard,
359 tensorDim, tgtTensorDim, gridAxis);
370 std::optional<std::tuple<TypedValue<ShapedType>,
Sharding>>
373 ShapedType srcUnshardedType,
388 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
389 assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) &&
390 ShapedType::isStaticShape(tgtHaloSizes) &&
391 srcShard.getType().hasStaticShape()) &&
392 "dynamic shapes/halos are not supported yet for shard-partition");
393 auto rank = srcShard.getType().getRank();
396 strides(rank, 1), outShape(srcShard.getType().getShape()),
397 coreShape(srcShard.getType().getShape());
401 for (
auto i = 0u; i < rank; ++i) {
402 if (i < splitAxes.size() && !splitAxes[i].empty()) {
403 if (!srcHaloSizes.empty()) {
404 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
405 srcCoreOffs[i] = srcHaloSizes[i * 2];
407 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
409 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
415 auto initVal = tensor::EmptyOp::create(builder, srcShard.
getLoc(), outShape,
416 srcShard.getType().getElementType());
417 auto core = tensor::ExtractSliceOp::create(
418 builder, srcShard.
getLoc(),
419 RankedTensorType::get(coreShape, srcShard.getType().getElementType()),
420 srcShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
421 auto initOprnd = tensor::InsertSliceOp::create(
422 builder, srcShard.
getLoc(), core, initVal, noVals, noVals, noVals,
423 tgtCoreOffs, coreShape, strides);
426 auto updateHaloResult =
427 UpdateHaloOp::create(builder, srcShard.
getLoc(),
428 RankedTensorType::get(
429 outShape, srcShard.getType().getElementType()),
430 initOprnd, grid.getSymName(),
444 GridOp grid,
const Sharding &srcSharding,
449 if (srcSharding == tgtSharding ||
454 assert(shardedSrc.getType() ==
456 [[maybe_unused]] ShapedType tgtShardType =
458 assert(shardedSrc.getType().getRank() == tgtShardType.getRank());
459 assert(unshardedSrc.getType().getRank() == tgtShardType.getRank());
467 &updateHaloPattern, &moveSplitAxisPattern, &splitLastAxisPattern,
468 &unsplitLastAxesPattern};
470 Sharding currentSharding = srcSharding;
472 dim < tgtShardType.getRank() && currentSharding != tgtSharding; ++dim) {
473 for (
auto &pattern : patterns) {
474 if (
auto tryRes = pattern->tryApply(builder, grid, dim, currentSharding,
475 tgtSharding, unshardedSrc.getType(),
477 std::tie(currentShard, currentSharding) = tryRes.value();
483 if (currentSharding != tgtSharding ||
484 currentShard.getType() != tgtShardType) {
486 <<
"Failed to reshard; probably hitting an unknown resharding pattern:"
487 <<
" got " << currentSharding <<
" expected " << tgtSharding
488 <<
" got type " << currentShard.getType() <<
" expected "
496 ShardOp srcShardOp, ShardOp tgtShardOp,
498 assert(srcShardOp.getResult() == tgtShardOp.getSrc());
499 auto srcSharding = srcShardOp.getSharding();
500 auto tgtSharding = tgtShardOp.getSharding();
502 return reshard(implicitLocOpBuilder, grid, srcSharding, tgtSharding,
503 srcShardOp.getSrc(), shardedSrc);
510 GridOp srcGrid =
getGrid(srcShardOp, symbolTableCollection);
511 assert(srcGrid && srcGrid ==
getGrid(tgtShardOp, symbolTableCollection));
512 return reshard(builder, srcGrid, srcShardOp, tgtShardOp, shardedSrc);
516 registry.
insert<shard::ShardDialect, tensor::TensorDialect>();
519#define GEN_PASS_DEF_PARTITION
520#include "mlir/Dialect/Shard/Transforms/Passes.h.inc"
534 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
535 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
536 rankedTensorArg.use_empty()) {
537 return arg.getType();
540 assert(rankedTensorArg.hasOneUse());
542 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
544 GridOp grid =
getGrid(shardOp, symbolTableCollection);
546 shardOp.getSharding()));
557 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
558 if (!shardingInterface) {
562 resultShardings, partitionMap,
563 symbolTableCollection, builder);
565 if (failed(shardingInterface.partition(
566 partitionedOperands, operandShardings, resultShardings,
567 partitionMap, symbolTableCollection, builder))) {
573 return partitionMap.contains(result);
582 std::vector<Sharding> res;
584 llvm::transform(op.
getOperands(), std::back_inserter(res), [](
Value operand) {
585 TypedValue<RankedTensorType> rankedTensor =
586 dyn_cast<TypedValue<RankedTensorType>>(operand);
587 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
593 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
594 return Sharding(shardOp.getSharding());
602 std::vector<Sharding> res;
606 if (!result.hasOneUse() || result.use_empty()) {
615 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
617 return Sharding(shardOp.getSharding());
619 if (rankedTensor.getType().getRank() == 0) {
624 if (
auto sharding = operand.getDefiningOp<ShardingOp>()) {
625 return Sharding(sharding.getGridAttr());
638 Value tgtPartitionValue;
642 ShardOp srcShardOp = shardOp.getSrc().
getDefiningOp<ShardOp>();
644 tgtPartitionValue = partitionMap.
lookup(shardOp.getSrc());
648 cast<TypedValue<ShapedType>>(partitionMap.
lookup(srcShardOp));
649 tgtPartitionValue =
reshard(builder, srcShardOp, shardOp, shardedSrc,
650 symbolTableCollection);
651 if (!tgtPartitionValue) {
652 return shardOp.emitError()
653 <<
"Failed to reshard from " << srcShardOp.getSharding() <<
" to "
654 << shardOp.getSharding();
658 assert(!partitionMap.
contains(shardOp.getResult()));
659 partitionMap.
map(shardOp.getResult(), tgtPartitionValue);
669 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
670 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0 ||
671 rankedTensorArg.use_empty())
674 if (!rankedTensorArg.hasOneUse())
676 <<
"Cannot partition: expected a single use for block argument "
677 << arg.getArgNumber() <<
" in block "
681 auto shardOp = dyn_cast<ShardOp>(useOp);
684 <<
"Cannot partition: expected a shard.shard op for block "
685 <<
"argument " << arg.getArgNumber() <<
" in block "
706 auto rankedTT = dyn_cast<RankedTensorType>(operand.get().getType());
707 if (!rankedTT || rankedTT.getRank() == 0)
710 auto shard = operand.get().getDefiningOp<ShardOp>();
712 return op->
emitError() <<
"Cannot partition: tensor operand "
713 << operand.getOperandNumber()
714 <<
" must be defined by a shard.shard operation.";
715 if (!
shard.getAnnotateForUsers())
717 <<
"Cannot partition: shard.shard for operand "
718 << operand.getOperandNumber() <<
" must set 'annotate_for_users'.";
723 <<
"Cannot partition: result " <<
result.getResultNumber()
724 <<
" must have exactly one use.";
725 auto shard = dyn_cast<ShardOp>(*
result.user_begin());
728 <<
"Cannot partition: user of result " <<
result.getResultNumber()
729 <<
" must be shard.shard operation.";
730 if (
shard.getAnnotateForUsers())
731 return op->
emitError() <<
"Cannot partition: shard.shard for result "
732 <<
result.getResultNumber()
733 <<
" must not set 'annotate_for_users'.";
742 if (isa<ShardingOp>(op)) {
746 if (
auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
747 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
749 return op.
emitError(
"expected a shard op as source of get_sharding");
751 auto newSharding = builder.
clone(*shardOp.getSharding().getDefiningOp());
752 partitionMap.
map(op.
getResult(0), newSharding->getResult(0));
756 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
767 llvm::transform(op.
getOperands(), std::back_inserter(partitionedOperands),
768 [&partitionMap](
Value operand) {
769 assert(partitionMap.contains(operand));
770 return partitionMap.lookup(operand);
774 symbolTableCollection, builder);
786 llvm::transform(block.
getArguments(), std::back_inserter(argLocations),
791 for (
auto [unshardedBlockArg, partitionedBlockArg] :
793 partitionMap.
map(unshardedBlockArg, partitionedBlockArg);
816 for (
Block &
b : op.getBlocks()) {
817 if (llvm::any_of(
b.getOperations(),
818 [](
Operation &op) { return isa<ShardOp>(op); })) {
819 originalBlocks.push_back(&
b);
823 for (
Block *block : originalBlocks) {
824 if (failed(
partitionBlock(*block, partitionMap, symbolTableCollection,
830 for (
Block *block : originalBlocks) {
837 for (
Block &block : op.getFunctionBody()) {
843 returnOp = &block.back();
848 op.setType(FunctionType::get(
849 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
858struct Partition :
public impl::PartitionBase<Partition> {
859 void runOnOperation()
override {
862 if (failed(partitionFuncOp(getOperation(), partitionMap,
863 symbolTableCollection))) {
864 return signalPassFailure();
868 void getDependentDialects(DialectRegistry ®istry)
const override {
870 registry.insert<shard::ShardDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
unsigned computeBlockNumber()
Compute the position of this block within its parent region using an O(N) linear scan.
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Location getLoc() const
Accessors for the implied location.
mlir::InFlightDiagnostic emitError(const llvm::Twine &message=llvm::Twine())
This builder can also be used to emit diagnostics to the current location.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
This class provides the API for a sub-set of ops that are known to be constant-like.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
Location getLoc()
Return a location for this region.
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int16_t > content)
Move a split axis between tensor dimensions: e.g.
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Base class for resharding patterns.
virtual std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard)=0
Try to apply this resharding pattern.
static bool hasStaticOffsetsOrHalos(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets or non-empty static halo sizes.
static bool hasStaticOffsets(const Sharding &srcSharding, const Sharding &tgtSharding)
Returns true if either sharding has non-empty static sharded dims offsets.
virtual ~ReshardingPattern()=default
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
bool equalSplitAxes(const Sharding &rhs) const
ArrayRef< int64_t > getStaticHaloSizes() const
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
ArrayRef< GridAxesAttr > getSplitAxes() const
bool equalHaloSizes(const Sharding &rhs) const
Split a replicated axis: e.g. [[0, 1]] -> [[0, 1, 2]].
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Unsplit trailing axes: e.g. [[0, 1, 2]] -> [[0, 1]] or [[0, 1, 2]] -> [].
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
Update halo sizes: handles cases where only the halo sizes differ between source and target sharding.
std::optional< std::tuple< TypedValue< ShapedType >, Sharding > > tryApply(ImplicitLocOpBuilder &builder, GridOp grid, int64_t tensorDim, const Sharding &srcSharding, const Sharding &tgtSharding, ShapedType srcUnshardedType, TypedValue< ShapedType > srcShard) override
Try to apply this resharding pattern.
ShapedType shardShapedType(ShapedType shape, GridOp grid, Sharding sharding)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection)
static LogicalResult checkFullyAnnotated(Block &block)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
bool isFullReplication(Sharding sharding)
static LogicalResult partitionBlock(Block &block, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::vector< Sharding > getOperandShardings(Operation &op)
DenseMap< Value, Value > UnshardedToShardedValueMap
static std::vector< Sharding > getResultShardings(Operation &op)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
TypedValue< ShapedType > reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
void reshardingRegisterDependentDialects(DialectRegistry ®istry)
shard::GridOp getGrid(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
static LogicalResult partitionOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
This trait indicates that a terminator operation is "return-like".