31 #include "llvm/ADT/APInt.h"
32 #include "llvm/ADT/DenseSet.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/Support/Casting.h"
39 #include <type_traits>
43 template <typename SourceAxes, typename TargetAxes>
44 static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
45 const TargetAxes &targetAxes) {
46 return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
47 return sourceAxes.contains(targetAxis);
64 return {sourceShard, sourceSharding};
69 using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
70 using AxisSet = llvm::SmallDenseSet<Axis>;
71 AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
73 AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
76 targetShardingPartialAxesSet));
78 llvm::copy_if(sourceShardingPartialAxesSet,
79 std::back_inserter(allReduceMeshAxes),
80 [&targetShardingPartialAxesSet](Axis a) {
81 return !targetShardingPartialAxesSet.contains(a);
83 if (allReduceMeshAxes.empty()) {
84 return {sourceShard, sourceSharding};
90 .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
92 allReduceMeshAxes, sourceShard,
97 llvm::copy_if(sourceShardingPartialAxesSet,
98 std::back_inserter(allReduceMeshAxes),
99 [&targetShardingPartialAxesSet](Axis a) {
100 return targetShardingPartialAxesSet.contains(a);
105 return {resultValue, resultSharding};
110 int64_t splitTensorAxis,
114 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
118 auto targetSplitAxes =
119 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
120 targetSplitAxes.push_back(splitMeshAxis);
121 targetShardingSplitAxes[splitTensorAxis] =
124 sourceSharding.getMeshAttr(), targetShardingSplitAxes,
135 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
138 .create<AllSliceOp>(sourceShard, mesh,
143 builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis);
144 return {targetShard, targetSharding};
152 static std::optional<std::tuple<int64_t, MeshAxis>>
155 for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
157 if (sourceSharding.getSplitAxes().size() > tensorAxis) {
158 if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
168 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
173 if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
177 return std::make_tuple(
179 targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
184 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
191 auto [tensorAxis, meshAxis] = detectRes.value();
193 tensorAxis, meshAxis);
202 static std::optional<std::tuple<int64_t, MeshAxis>>
205 for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
207 if (targetSharding.getSplitAxes().size() > tensorAxis) {
216 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
218 targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
221 if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
224 return std::make_tuple(
226 sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
233 int64_t splitTensorAxis) {
236 assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
238 auto targetSplitAxes =
239 llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
241 targetSplitAxes.pop_back();
242 targetShardingSplitAxes[splitTensorAxis] =
245 sourceSharding.getMeshAttr(), targetShardingSplitAxes,
250 ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
252 targetShape[splitTensorAxis] =
254 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
260 ShapedType sourceUnshardedShape,
262 int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
269 sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis);
270 Value allGatherResult = builder.create<AllGatherOp>(
272 allGatherResultShape.getElementType()),
274 APInt(64, splitTensorAxis));
275 ShapedType targetShape =
278 builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult());
279 return {targetShard, targetSharding};
282 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
286 ShapedType sourceUnshardedShape,
290 auto [tensorAxis, meshAxis] = detectRes.value();
292 sourceUnshardedShape, sourceShard, mesh,
293 tensorAxis, meshAxis);
304 static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
307 for (size_t sourceTensorAxis = 0;
308 sourceTensorAxis < sourceSharding.getSplitAxes().size();
309 ++sourceTensorAxis) {
310 for (size_t targetTensorAxis = 0;
311 targetTensorAxis < targetSharding.getSplitAxes().size();
312 ++targetTensorAxis) {
313 if (sourceTensorAxis == targetTensorAxis)
315 if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
316 targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
317 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
323 llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
330 llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
338 return std::make_tuple(
339 sourceTensorAxis, targetTensorAxis,
340 sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
348 int64_t sourceTensorAxis,
349 int64_t targetTensorAxis) {
352 while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
357 auto sourceSplitAxes =
358 llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
359 assert(!sourceSplitAxes.empty());
360 auto meshAxis = sourceSplitAxes.back();
361 sourceSplitAxes.pop_back();
362 targetShardingSplitAxes[sourceTensorAxis] =
365 auto targetSplitAxes =
366 llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
367 targetSplitAxes.push_back(meshAxis);
368 targetShardingSplitAxes[targetTensorAxis] =
372 sourceSharding.getMeshAttr(), targetShardingSplitAxes,
378 int64_t sourceTensorAxis,
379 int64_t targetTensorAxis) {
381 targetShape[sourceTensorAxis] =
383 targetShape[targetTensorAxis] =
385 return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
391 ShapedType sourceUnshardedShape,
393 int64_t sourceTensorAxis,
394 int64_t targetTensorAxis, MeshAxis meshAxis) {
399 ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
401 sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis,
405 allToAllResultShape.getElementType()),
407 APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
408 ShapedType targetShape =
411 builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult());
412 return {targetShard, targetSharding};
415 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
419 ShapedType sourceUnshardedShape,
423 auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
425 builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
426 sourceTensorAxis, targetTensorAxis, meshAxis);
436 static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>>
440 ShapedType sourceUnshardedShape,
455 assert(srcHaloSizes.empty() || srcHaloSizes.size() == tgtHaloSizes.size());
456 assert(((srcHaloSizes.empty() || !ShapedType::isDynamicShape(srcHaloSizes)) &&
457 !ShapedType::isDynamicShape(tgtHaloSizes) &&
458 sourceShard.getType().hasStaticShape()) &&
459 "dynamic shapes/halos are not supported yet for mesh-spmdization");
460 auto rank = sourceShard.getType().getRank();
463 strides(rank, 1), outShape(sourceShard.getType().getShape()),
464 coreShape(sourceShard.getType().getShape());
468 for (auto i = 0u; i < rank; ++i) {
469 if (i < splitAxes.size() && !splitAxes[i].empty()) {
470 if (!srcHaloSizes.empty()) {
471 coreShape[i] -= srcHaloSizes[i * 2] + srcHaloSizes[i * 2 + 1];
472 srcCoreOffs[i] = srcHaloSizes[i * 2];
474 tgtCoreOffs[i] = tgtHaloSizes[i * 2];
476 coreShape[i] + tgtHaloSizes[i * 2] + tgtHaloSizes[i * 2 + 1];
482 auto initVal = builder.create<tensor::EmptyOp>(
483 sourceShard.getLoc(), outShape, sourceShard.getType().getElementType());
484 auto core = builder.create<tensor::ExtractSliceOp>(
485 sourceShard.getLoc(),
487 sourceShard, noVals, noVals, noVals, srcCoreOffs, coreShape, strides);
488 auto initOprnd = builder.create<tensor::InsertSliceOp>(
489 sourceShard.getLoc(), core, initVal, noVals, noVals, noVals, tgtCoreOffs,
493 auto updateHaloResult =
496 sourceShard.getLoc(),
498 sourceShard.getType().getElementType()),
499 initOprnd, mesh.getSymName(),
517 assert(sourceShard.getType() ==
518 shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
519 [[maybe_unused]] ShapedType targetShardType =
521 assert(sourceShard.getType().getRank() == targetShardType.getRank());
522 assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
524 auto [reducedSourceShard, reducedSourceSharding] =
528 if (reducedSourceSharding == targetSharding) {
529 return reducedSourceShard;
534 if (reducedSourceSharding.getStaticShardedDimsOffsets().empty() &&
536 reducedSourceSharding.getStaticHaloSizes().empty() &&
539 builder, mesh, reducedSourceSharding, targetSharding,
540 sourceUnshardedValue.getType(), reducedSourceShard)) {
541 std::tie(targetShard, actualTargetSharding) = tryRes.value();
543 builder, mesh, reducedSourceSharding, targetSharding,
544 reducedSourceShard)) {
545 std::tie(targetShard, actualTargetSharding) = tryRes.value();
547 builder, mesh, reducedSourceSharding, targetSharding,
548 sourceUnshardedValue.getType(), reducedSourceShard)) {
549 std::tie(targetShard, actualTargetSharding) = tryRes.value();
552 assert(targetShard && "Did not find any pattern to apply.");
553 assert(actualTargetSharding == targetSharding);
554 assert(targetShard.getType() == targetShardType);
572 builder, mesh, sourceSharding, targetSharding,
573 sourceUnshardedValue.getType(), sourceShard)) {
574 return std::get<0>(tryRes.value());
581 sourceUnshardedValue, sourceShard);
587 assert(source.getResult() == target.getSrc());
588 auto sourceSharding = source.getSharding();
589 auto targetSharding = target.getSharding();
591 return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding,
601 assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection));
602 return reshard(builder, srcMesh, source, target, sourceShardValue);
606 registry.insert<mesh::MeshDialect, tensor::TensorDialect>();
609 #define GEN_PASS_DEF_SPMDIZATION
610 #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc"
624 auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg);
625 if (!rankedTensorArg || rankedTensorArg.getType().getRank() == 0) {
626 return arg.getType();
629 assert(rankedTensorArg.hasOneUse());
631 ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp);
635 shardOp.getSharding()));
645 ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op);
646 if (!shardingInterface) {
650 resultShardings, spmdizationMap,
651 symbolTableCollection, builder);
653 if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings,
654 resultShardings, spmdizationMap,
655 symbolTableCollection, builder))) {
661 return spmdizationMap.contains(result);
670 std::vector<MeshSharding> res;
672 llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) {
673 TypedValue<RankedTensorType> rankedTensor =
674 dyn_cast<TypedValue<RankedTensorType>>(operand);
675 if (!rankedTensor || rankedTensor.getType().getRank() == 0) {
676 return MeshSharding();
681 ShardOp shardOp = llvm::cast<ShardOp>(definingOp);
690 std::vector<MeshSharding> res;
694 if (!result.hasOneUse() || result.use_empty()) {
695 return MeshSharding();
703 ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp);
707 if (rankedTensor.getType().getRank() == 0) {
712 if (auto sharding = operand.getDefiningOp<ShardingOp>()) {
726 Value targetSpmdValue;
731 dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp());
733 targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc());
737 cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp));
738 targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue,
739 symbolTableCollection);
742 assert(!spmdizationMap.contains(shardOp.getResult()));
743 spmdizationMap.map(shardOp.getResult(), targetSpmdValue);
751 if (isa<ShardingOp>(op)) {
754 if (auto getShardingOp = dyn_cast<GetShardingOp>(op)) {
755 auto shardOp = getShardingOp.getSource().getDefiningOp<ShardOp>();
757 return op.emitError("expected a shard op as source of get_sharding");
759 auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp());
760 spmdizationMap.map(op.getResult(0), newSharding->getResult(0));
764 ShardOp shardOp = llvm::dyn_cast<ShardOp>(op);
771 llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands),
772 [&spmdizationMap](Value operand) {
773 assert(spmdizationMap.contains(operand));
774 return spmdizationMap.lookup(operand);
778 symbolTableCollection, builder);
786 llvm::transform(block.getArguments(), std::back_inserter(argLocations),
791 for (auto [unshardedBlockArg, spmdizedBlockArg] :
793 spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg);
816 for (Block &b : op.getBlocks()) {
817 if (llvm::any_of(b.getOperations(),
818 [](Operation &op) { return isa<ShardOp>(op); })) {
819 originalBlocks.push_back(&b);
823 for (Block *block : originalBlocks) {
824 if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection,
830 for (Block *block : originalBlocks) {
837 for (Block &block : op.getFunctionBody()) {
843 returnOp = &block.back();
849 op->getContext(), op.getFunctionBody().front().getArgumentTypes(),
858 struct Spmdization : public impl::SpmdizationBase<Spmdization> {
859 void runOnOperation() override {
863 symbolTableCollection))) {
864 return signalPassFailure();
868 void getDependentDialects(DialectRegistry &registry) const override {
870 registry.insert<mesh::MeshDialect>();
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
OpListType & getOperations()
BlockArgListType getArguments()
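A minimal sketch of how these Block accessors are typically used, in the spirit of spmdizeBlock above; it assumes the headers already included by this file and the function name is illustrative:
static void inspectBlock(mlir::Block &block) {
  // Block arguments are ordinary Values and can be remapped like any other.
  for (mlir::BlockArgument arg : block.getArguments())
    (void)arg;
  // Operations are stored in program order.
  for (mlir::Operation &op : block.getOperations())
    (void)op;
}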
MLIRContext * getContext() const
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This is a utility class for mapping one set of IR entities to another.
auto lookup(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
bool contains(T from) const
Checks to see if a mapping for 'from' exists.
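A minimal sketch of the IRMapping calls used by the spmdization map in this file; the function and parameter names are hypothetical and the headers are assumed to be those already included here:
static mlir::Value remapValue(mlir::IRMapping &spmdizationMap,
                              mlir::Value unsharded, mlir::Value sharded) {
  spmdizationMap.map(unsharded, sharded);      // record the replacement
  assert(spmdizationMap.contains(unsharded));  // membership check
  return spmdizationMap.lookup(unsharded);     // returns `sharded`
}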
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
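A minimal sketch of ImplicitLocOpBuilder, assuming this file's headers, a statically shaped reference value, and illustrative names; the location passed at construction is reused by every create<> call:
static mlir::Value buildEmptyLike(mlir::TypedValue<mlir::ShapedType> reference) {
  mlir::ImplicitLocOpBuilder builder(reference.getLoc(), reference.getContext());
  builder.setInsertionPointAfterValue(reference);
  // No location argument needed: the builder supplies reference.getLoc().
  return builder
      .create<mlir::tensor::EmptyOp>(reference.getType().getShape(),
                                     reference.getType().getElementType())
      .getResult();
}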
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
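A minimal sketch of the OpBuilder pieces used when re-creating operations during spmdization (position the builder, then clone through an IRMapping); names are illustrative and this file's headers are assumed:
static mlir::Operation *cloneRemapped(mlir::Operation &original,
                                      mlir::Block &destination,
                                      mlir::IRMapping &mapping,
                                      mlir::OpBuilder &builder) {
  builder.setInsertionPointToEnd(&destination); // append to the target block
  return builder.clone(original, mapping);      // operands looked up in `mapping`
}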
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
user_range getUsers()
Returns a range of all users.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
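A minimal sketch of the Operation accessors referenced above (error reporting plus result iteration), with an illustrative function name and this file's headers assumed:
static mlir::LogicalResult requireResults(mlir::Operation &op) {
  if (op.getNumResults() == 0)
    return op.emitError("expected an operation that produces results");
  for (mlir::OpResult result : op.getResults())
    (void)result; // each result is itself a Value with its own users
  return mlir::success();
}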
This class represents a collection of SymbolTables.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
::mlir::FlatSymbolRefAttr getMeshAttr() const
bool equalHaloSizes(const MeshSharding &rhs) const
ArrayRef< MeshAxesAttr > getSplitAxes() const
ReductionKind getPartialType() const
ArrayRef< MeshAxis > getPartialAxes() const
ArrayRef< Value > getDynamicHaloSizes() const
ArrayRef< int64_t > getStaticHaloSizes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
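A minimal sketch of building a MeshSharding through the get() overload above; the MLIRContext, the mesh name "mesh0", and the function name are illustrative assumptions, not code from this file, and the headers already included here are assumed:
static mlir::mesh::MeshSharding makeDim0Sharding(mlir::MLIRContext *ctx) {
  auto meshRef = mlir::FlatSymbolRefAttr::get(ctx, "mesh0");
  // Split tensor dimension 0 over mesh axis 0; all other dimensions replicate.
  llvm::SmallVector<mlir::mesh::MeshAxis> axes = {0};
  llvm::SmallVector<mlir::mesh::MeshAxesAttr> splitAxes = {
      mlir::mesh::MeshAxesAttr::get(ctx, axes)};
  return mlir::mesh::MeshSharding::get(meshRef, splitAxes);
}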
mesh::MeshSharding MeshSharding
static LogicalResult spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection)
static std::tuple< TypedValue< ShapedType >, MeshSharding > handlePartialAxesDuringResharding(OpBuilder &builder, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
SmallVector< Type > shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection)
static ShapedType allGatherResultShapeInUnsplitLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
static std::optional< std::tuple< int64_t, int64_t, MeshAxis > > detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, int64_t splitCount, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceShard)
static std::vector< MeshSharding > getResultShardings(Operation &op)
static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, const TargetAxes &targetAxes)
static std::optional< std::tuple< int64_t, MeshAxis > > detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis)
static std::tuple< TypedValue< ShapedType >, MeshSharding > unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
TypedValue< ShapedType > reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, TypedValue< ShapedType > sourceShardValue)
bool isFullReplication(MeshSharding sharding)
static TypedValue< ShapedType > reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
void reshardingRegisterDependentDialects(DialectRegistry &registry)
static std::tuple< TypedValue< ShapedType >, MeshSharding > splitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshSharding sourceSharding, TypedValue< ShapedType > sourceShard, MeshOp mesh, int64_t splitTensorAxis, MeshAxis splitMeshAxis)
static std::vector< MeshSharding > getOperandShardings(Operation &op)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeOperation(Operation &op, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, MeshSharding sourceSharding, int64_t sourceTensorAxis, int64_t targetTensorAxis)
static std::optional< std::tuple< TypedValue< ShapedType >, MeshSharding > > tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard)
static std::tuple< TypedValue< ShapedType >, MeshSharding > moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue< ShapedType > sourceShard, int64_t sourceTensorAxis, int64_t targetTensorAxis, MeshAxis meshAxis)
TypedValue< ShapedType > reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, MeshSharding sourceSharding, MeshSharding targetSharding, TypedValue< ShapedType > sourceUnshardedValue, TypedValue< ShapedType > sourceShard)
static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder)
static std::optional< std::tuple< int64_t, MeshAxis > > detectSplitLastAxisInResharding(MeshSharding sourceSharding, MeshSharding targetSharding)
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
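A minimal sketch of why TypedValue is convenient in the helpers above: the static type travels with the value, so no cast is needed at the use site (illustrative name, headers assumed):
static mlir::ShapedType shapeOf(mlir::TypedValue<mlir::ShapedType> value) {
  return value.getType(); // already a ShapedType, no cast<ShapedType>() needed
}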
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
This trait indicates that a terminator operation is "return-like".