16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
63 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
64 if ((
size_t)position >= seenIds.size() || seenIds[position])
66 seenIds[position] =
true;
70 unsigned position = cast<AffineDimExpr>(expr).getPosition();
71 if ((
size_t)position >= seenIds.size() || seenIds[position])
73 seenIds[position] =
true;
81 static FailureOr<llvm::SmallSet<unsigned, 2>>
87 llvm::SmallSet<unsigned, 2> positions;
90 positions.insert((
unsigned)it.index());
99 for (
const auto &v : vec) {
109 FailureOr<std::pair<bool, MeshSharding>>
111 Value val = cast<Value>(result);
113 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
116 return !shardOp.getAnnotateForUsers();
119 if (anyShardedForDef) {
124 auto shardOp = llvm::cast<mesh::ShardOp>(*val.
getUsers().begin());
125 return std::make_pair(
false,
MeshSharding(shardOp.getSharding()));
129 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
132 return shardOp.getAnnotateForUsers();
134 if (anyShardedForUsers) {
137 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
139 shardOps.push_back(shardOp);
142 for (
size_t i = 1; i < shardOps.size(); ++i) {
146 "only support all shard ops have the same mesh sharding attr");
148 return std::make_pair(
true, shardForDef);
153 FailureOr<std::pair<bool, MeshSharding>>
157 return std::make_pair(shardOp.getAnnotateForUsers(),
167 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
172 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
175 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
184 if (numOperands + numResults != maps.size())
188 auto resultType = dyn_cast<RankedTensorType>(result.getType());
191 AffineMap map = maps[numOperands + result.getResultNumber()];
204 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
205 os <<
"print loop types and indexing maps for: \n";
206 getOperation()->print(os);
208 os <<
"loop types: [";
209 for (utils::IteratorType type : getLoopIteratorTypes()) {
210 os << stringifyEnum(type) <<
" ";
213 os <<
"indexing maps: \n";
226 static LogicalResult fillShardingOption(
Operation *op,
231 if ((shardingOption.
mesh && mesh && shardingOption.
mesh != mesh) ||
234 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts on loop iterator "
238 for (
size_t i = 0; i < shardingOption.
shardingArray.size(); ++i) {
243 if (llvm::is_contained(shardingOption.
shardingArray[i], axis)) {
244 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts because mesh axes "
245 << axis <<
" duplicate");
251 shardingOption.
mesh = mesh;
253 shardingOption.
shardingArray[loopIdx].append(meshAxes.begin(),
260 FailureOr<ShardingOption>
264 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
267 if (failed(shardingOp.verifyShardingInterfaceImpl()))
268 return op->
emitOpError() <<
"invalid sharding interface implementation";
270 shardingOp.getLoopIteratorTypes();
275 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
276 bool anyShardingInResultsOrOperands =
false;
283 AffineMap map = maps[numOperands + shardingIt.index()];
284 anyShardingInResultsOrOperands =
true;
294 auto dim = cast<AffineDimExpr>(expr);
295 unsigned index = dim.getPosition();
296 visitedLoopIndices.insert(index);
297 if (failed(fillShardingOption(op, shardingOption,
306 if (!partialAxes.empty()) {
307 if (!partialMeshAxes.empty())
308 return op->
emitOpError() <<
"at most one result with partial axes is "
309 "supported at present";
310 partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
313 for (
size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
315 visitedLoopIndices.insert(loopIdx);
326 anyShardingInResultsOrOperands = !shardAttr.
getSplitAxes().empty();
327 AffineMap map = maps[shardingIt.index()];
338 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
340 if (failed(loopIndices))
342 <<
"operand's affine expression is restricted to const_i * "
343 "dim_i + const_j + dim_j + ...";
344 if (loopIndices->empty())
346 if (loopIndices->size() == 1) {
347 unsigned loopIdx = *loopIndices->begin();
348 visitedLoopIndices.insert(loopIdx);
349 if (failed(fillShardingOption(op, shardingOption,
356 if (loopIndices->size() > 1) {
357 bool seenLoopIndices =
false;
358 for (
unsigned loopIdx : *loopIndices) {
359 if (visitedLoopIndices.contains(loopIdx)) {
360 seenLoopIndices =
true;
364 if (!seenLoopIndices)
366 <<
"the operand " << shardingIt.index()
367 <<
" has multiple loop indices in a dimension, but none of "
368 "them could be found in the exactly specified annotation "
369 "of op results or operands.";
375 if (!partialMeshAxes.empty()) {
376 bool anyNonEmptyReductionLoop = llvm::any_of(
378 SmallVector<MeshAxis> &subArray = it.value();
379 int64_t idx = it.index();
380 return isReductionLoop(loopTypes[idx]) && !subArray.empty();
382 if (!anyNonEmptyReductionLoop) {
384 for (
size_t idx = 0; idx < loopTypes.size(); ++idx) {
386 std::ignore = fillShardingOption(op, shardingOption,
nullptr,
387 partialMeshAxes, idx);
393 return op->
emitOpError() <<
"no matched reduction loop found for the "
394 "result's partial type";
398 if (!anyShardingInResultsOrOperands)
399 shardingOption.
empty =
true;
400 return shardingOption;
407 auto resultType = cast<RankedTensorType>(result.
getType());
416 auto dim = cast<AffineDimExpr>(expr);
417 unsigned loopIdx = dim.getPosition();
419 splitAxes[it.index()].append(shardingOption.
shardingArray[loopIdx]);
425 size_t reductionLoopKindsIdx = 0;
426 for (
auto it : llvm::zip(loopTypes, shardingOption.
shardingArray)) {
427 utils::IteratorType iType = std::get<0>(it);
429 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
430 ++reductionLoopKindsIdx;
431 if (!partialAxes.empty())
432 assert(partialType == curPartialType &&
433 "Only one reduction type is supported");
434 partialType = curPartialType;
436 partialAxes.append(axis);
443 partialAxes, partialType);
449 Value operandValue = opOperand.
get();
450 auto operandType = dyn_cast<RankedTensorType>(operandValue.
getType());
457 if (operandType.getRank() == 0) {
463 int64_t idx = it.index();
465 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
467 if (failed(loopIndices))
470 for (
unsigned loopIdx : *loopIndices) {
473 shardedLoopIndices.push_back(loopIdx);
476 if (shardedLoopIndices.size() > 1)
478 if (shardedLoopIndices.size() == 1) {
479 splitAxes[idx].append(
490 FailureOr<std::vector<MeshSharding>>
493 std::vector<MeshSharding> res;
495 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
497 shardingOp.getLoopIteratorTypes();
499 shardingOp.getReductionLoopIteratorKinds();
504 FailureOr<MeshSharding> shardingAttr =
getSharding(
505 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
506 if (failed(shardingAttr))
508 res.push_back(*shardingAttr);
513 maps[numOperands + result.getResultNumber()],
514 loopTypes, reductionKinds));
532 getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
544 FailureOr<MeshSharding> sharding =
546 if (failed(sharding)) {
557 assert(!shardingOption.
empty && shardingOption.
mesh);
559 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
561 shardingOp.getLoopIteratorTypes();
563 shardingOp.getReductionLoopIteratorKinds();
569 if (failed(
addShardOp(b, result, shardingOption,
570 maps[numOperands + result.getResultNumber()],
571 loopTypes, reductionKinds)))
577 if (failed(
addShardOp(b, opOperand, shardingOption,
578 maps[opOperand.getOperandNumber()])))
589 if (isa<RankedTensorType>(value.
getType())) {
596 template <
typename ValueRange,
typename MeshShardingRage>
599 MeshShardingRage &&shardings) {
600 if (std::size(values) != std::size(shardings)) {
604 llvm::zip_equal(std::forward<ValueRange>(values),
605 std::forward<MeshShardingRage>(shardings)),
606 [](
auto valueAndSharding) {
608 std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
618 assert(spmdizedOperands.size() == operandShardings.size());
624 builder.
clone(op, spmdizationMap);
630 &meshAxesAssignmentForLoopIterators) {
631 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
632 unsigned loopIteratorIdx = affineDimExpr.
getPosition();
633 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
634 assert(llvm::equal(meshAxesAssignmentForTensorAxis,
635 *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
637 meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
638 llvm::to_vector(meshAxesAssignmentForTensorAxis);
648 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
649 std::vector<MeshSharding> operatorAndResultShardings;
650 operatorAndResultShardings.reserve(operandShardings.size() +
651 resultShardings.size());
652 llvm::append_range(operatorAndResultShardings, operandShardings);
653 for (
auto [sharding, affineMap] :
654 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
658 for (
auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
659 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
661 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
662 meshAxisAssignmentForLoopIterators);
665 for (
unsigned i = sharding.getSplitAxes().size();
666 i < affineMap.getNumResults(); ++i) {
668 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
673 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
678 return std::move(*axes);
686 for (
auto [loopIteratorType, meshAxisAssignment] :
687 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
688 if (loopIteratorType == utils::IteratorType::reduction &&
689 !meshAxisAssignment.empty()) {
700 for (
auto [loopIteratorType, meshAxisAssignment] :
701 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
702 if (loopIteratorType == utils::IteratorType::reduction) {
703 llvm::append_range(meshAxes, meshAxisAssignment);
717 for (
auto [oldResult, newResult, sharding] :
721 getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Values.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
::mlir::FlatSymbolRefAttr getMeshAttr() const
ArrayRef< MeshAxesAttr > getSplitAxes() const
ArrayRef< MeshAxis > getPartialAxes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isReductionLoop(utils::IteratorType iType)
bool isFullReplication(MeshSharding sharding)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
ShardingArray shardingArray