16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
62 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63 if ((
size_t)position >= seenIds.size() || seenIds[position])
65 seenIds[position] =
true;
69 unsigned position = cast<AffineDimExpr>(expr).getPosition();
70 if ((
size_t)position >= seenIds.size() || seenIds[position])
72 seenIds[position] =
true;
80 static FailureOr<llvm::SmallSet<unsigned, 2>>
86 llvm::SmallSet<unsigned, 2> positions;
89 positions.insert((
unsigned)it.index());
98 for (
const auto &v : vec) {
108 FailureOr<std::pair<bool, MeshSharding>>
110 Value val = cast<Value>(result);
112 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
115 return !shardOp.getAnnotateForUsers();
118 if (anyShardedForDef) {
123 auto shardOp = llvm::cast<mesh::ShardOp>(*val.
getUsers().begin());
124 return std::make_pair(
false,
MeshSharding(shardOp.getSharding()));
128 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
131 return shardOp.getAnnotateForUsers();
133 if (anyShardedForUsers) {
136 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
138 shardOps.push_back(shardOp);
141 for (
size_t i = 1; i < shardOps.size(); ++i) {
145 "only support all shard ops have the same mesh sharding attr");
147 return std::make_pair(
true, shardForDef);
152 FailureOr<std::pair<bool, MeshSharding>>
156 return std::make_pair(shardOp.getAnnotateForUsers(),
166 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
171 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
174 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
183 if (numOperands + numResults != maps.size())
187 auto resultType = dyn_cast<RankedTensorType>(result.getType());
190 AffineMap map = maps[numOperands + result.getResultNumber()];
203 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
204 os <<
"print loop types and indexing maps for: \n";
205 getOperation()->print(os);
207 os <<
"loop types: [";
208 for (utils::IteratorType type : getLoopIteratorTypes()) {
209 os << stringifyEnum(type) <<
" ";
212 os <<
"indexing maps: \n";
225 static LogicalResult fillShardingOption(
Operation *op,
230 if ((shardingOption.
mesh && mesh && shardingOption.
mesh != mesh) ||
233 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts on loop iterator "
237 for (
size_t i = 0; i < shardingOption.
shardingArray.size(); ++i) {
242 if (llvm::is_contained(shardingOption.
shardingArray[i], axis)) {
243 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts because mesh axes "
244 << axis <<
" duplicate");
250 shardingOption.
mesh = mesh;
252 shardingOption.
shardingArray[loopIdx].append(meshAxes.begin(),
259 FailureOr<ShardingOption>
263 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
266 if (failed(shardingOp.verifyShardingInterfaceImpl()))
267 return op->
emitOpError() <<
"invalid sharding interface implementation";
269 shardingOp.getLoopIteratorTypes();
274 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
275 bool anyShardingInResultsOrOperands =
false;
282 AffineMap map = maps[numOperands + shardingIt.index()];
283 anyShardingInResultsOrOperands =
true;
293 auto dim = cast<AffineDimExpr>(expr);
294 unsigned index = dim.getPosition();
295 visitedLoopIndices.insert(index);
296 if (failed(fillShardingOption(op, shardingOption,
305 if (!partialAxes.empty()) {
306 if (!partialMeshAxes.empty())
307 return op->
emitOpError() <<
"at most one result with partial axes is "
308 "supported at present";
309 partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
312 for (
size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
314 visitedLoopIndices.insert(loopIdx);
325 anyShardingInResultsOrOperands = !shardAttr.
getSplitAxes().empty();
326 AffineMap map = maps[shardingIt.index()];
337 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
339 if (failed(loopIndices))
341 <<
"operand's affine expression is restricted to const_i * "
342 "dim_i + const_j + dim_j + ...";
343 if (loopIndices->empty())
345 if (loopIndices->size() == 1) {
346 unsigned loopIdx = *loopIndices->begin();
347 visitedLoopIndices.insert(loopIdx);
348 if (failed(fillShardingOption(op, shardingOption,
355 if (loopIndices->size() > 1) {
356 bool seenLoopIndices =
false;
357 for (
unsigned loopIdx : *loopIndices) {
358 if (visitedLoopIndices.contains(loopIdx)) {
359 seenLoopIndices =
true;
363 if (!seenLoopIndices)
365 <<
"the operand " << shardingIt.index()
366 <<
" has multiple loop indices in a dimension, but none of "
367 "them could be found in the exactly specified annotation "
368 "of op results or operands.";
374 if (!partialMeshAxes.empty()) {
375 bool anyNonEmptyReductionLoop = llvm::any_of(
377 SmallVector<MeshAxis> &subArray = it.value();
378 int64_t idx = it.index();
379 return isReductionLoop(loopTypes[idx]) && !subArray.empty();
381 if (!anyNonEmptyReductionLoop) {
383 for (
size_t idx = 0; idx < loopTypes.size(); ++idx) {
385 std::ignore = fillShardingOption(op, shardingOption,
nullptr,
386 partialMeshAxes, idx);
392 return op->
emitOpError() <<
"no matched reduction loop found for the "
393 "result's partial type";
397 if (!anyShardingInResultsOrOperands)
398 shardingOption.
empty =
true;
399 return shardingOption;
406 auto resultType = cast<RankedTensorType>(result.
getType());
416 auto dim = cast<AffineDimExpr>(expr);
417 unsigned loopIdx = dim.getPosition();
419 splitAxes[it.index()].append(shardingOption.
shardingArray[loopIdx]);
425 size_t reductionLoopKindsIdx = 0;
426 for (
auto it : llvm::zip(loopTypes, shardingOption.
shardingArray)) {
427 utils::IteratorType iType = std::get<0>(it);
429 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
430 ++reductionLoopKindsIdx;
431 if (!partialAxes.empty())
432 assert(partialType == curPartialType &&
433 "Only one reduction type is supported");
434 partialType = curPartialType;
436 partialAxes.append(axis);
443 partialAxes, partialType);
449 Value operandValue = opOperand.
get();
450 auto operandType = dyn_cast<RankedTensorType>(operandValue.
getType());
457 if (operandType.getRank() == 0) {
463 int64_t idx = it.index();
465 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
467 if (failed(loopIndices))
470 for (
unsigned loopIdx : *loopIndices) {
473 shardedLoopIndices.push_back(loopIdx);
476 if (shardedLoopIndices.size() > 1)
478 if (shardedLoopIndices.size() == 1) {
479 splitAxes[idx].append(
490 FailureOr<std::vector<MeshSharding>>
493 std::vector<MeshSharding> res;
495 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
497 shardingOp.getLoopIteratorTypes();
499 shardingOp.getReductionLoopIteratorKinds();
504 FailureOr<MeshSharding> shardingAttr =
getSharding(
505 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
506 if (failed(shardingAttr))
508 res.push_back(*shardingAttr);
513 maps[numOperands + result.getResultNumber()],
514 loopTypes, reductionKinds));
532 getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
544 FailureOr<MeshSharding> sharding =
546 if (failed(sharding)) {
557 assert(!shardingOption.
empty && shardingOption.
mesh);
559 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
561 shardingOp.getLoopIteratorTypes();
563 shardingOp.getReductionLoopIteratorKinds();
569 if (failed(
addShardOp(b, result, shardingOption,
570 maps[numOperands + result.getResultNumber()],
571 loopTypes, reductionKinds)))
577 if (failed(
addShardOp(b, opOperand, shardingOption,
578 maps[opOperand.getOperandNumber()])))
589 if (isa<RankedTensorType>(value.
getType())) {
596 template <
typename ValueRange,
typename MeshShardingRage>
599 MeshShardingRage &&shardings) {
600 if (std::size(values) != std::size(shardings)) {
604 llvm::zip_equal(std::forward<ValueRange>(values),
605 std::forward<MeshShardingRage>(shardings)),
606 [](
auto valueAndSharding) {
608 std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
618 assert(spmdizedOperands.size() == operandShardings.size());
624 builder.
clone(op, spmdizationMap);
630 &meshAxesAssignmentForLoopIterators) {
631 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
632 unsigned loopIteratorIdx = affineDimExpr.
getPosition();
633 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
634 assert(llvm::equal(meshAxesAssignmentForTensorAxis,
635 *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
637 meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
638 llvm::to_vector(meshAxesAssignmentForTensorAxis);
648 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
649 std::vector<MeshSharding> operatorAndResultShardings;
650 operatorAndResultShardings.reserve(operandShardings.size() +
651 resultShardings.size());
652 llvm::append_range(operatorAndResultShardings, operandShardings);
653 for (
auto [sharding, affineMap] :
654 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
658 for (
auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
659 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
661 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
662 meshAxisAssignmentForLoopIterators);
665 for (
unsigned i = sharding.getSplitAxes().size();
666 i < affineMap.getNumResults(); ++i) {
668 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
673 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
678 return std::move(*axes);
686 for (
auto [loopIteratorType, meshAxisAssignment] :
687 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
688 if (loopIteratorType == utils::IteratorType::reduction &&
689 !meshAxisAssignment.empty()) {
700 for (
auto [loopIteratorType, meshAxisAssignment] :
701 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
702 if (loopIteratorType == utils::IteratorType::reduction) {
703 llvm::append_range(meshAxes, meshAxisAssignment);
717 for (
auto [oldResult, newResult, sharding] :
721 getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding));
SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
::mlir::FlatSymbolRefAttr getMeshAttr() const
ArrayRef< MeshAxesAttr > getSplitAxes() const
ArrayRef< MeshAxis > getPartialAxes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isReductionLoop(utils::IteratorType iType)
bool isFullReplication(MeshSharding sharding)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
ShardingArray shardingArray