16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
62 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63 if ((
size_t)position >= seenIds.size() || seenIds[position])
65 seenIds[position] =
true;
69 unsigned position = cast<AffineDimExpr>(expr).getPosition();
70 if ((
size_t)position >= seenIds.size() || seenIds[position])
72 seenIds[position] =
true;
80 static FailureOr<llvm::SmallSet<unsigned, 2>>
86 llvm::SmallSet<unsigned, 2> positions;
89 positions.insert((
unsigned)it.index());
98 for (
const auto &v : vec) {
108 FailureOr<std::pair<bool, MeshSharding>>
110 Value val = cast<Value>(result);
112 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
115 return !shardOp.getAnnotateForUsers();
118 if (anyShardedForDef) {
123 auto shardOp = llvm::cast<mesh::ShardOp>(*val.
getUsers().begin());
124 return std::make_pair(
false,
MeshSharding(shardOp.getSharding()));
128 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
131 return shardOp.getAnnotateForUsers();
133 if (anyShardedForUsers) {
136 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
138 shardOps.push_back(shardOp);
141 for (
size_t i = 1; i < shardOps.size(); ++i) {
145 "only support all shard ops have the same mesh sharding attr");
147 return std::make_pair(
true, shardForDef);
152 FailureOr<std::pair<bool, MeshSharding>>
156 return std::make_pair(shardOp.getAnnotateForUsers(),
166 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
171 if (!llvm::isa<RankedTensorType>(type))
174 if (!llvm::isa<RankedTensorType>(type))
179 if (loopTypes.empty())
188 if (numOperands + numResults != maps.size())
192 auto resultType = dyn_cast<RankedTensorType>(result.getType());
195 AffineMap map = maps[numOperands + result.getResultNumber()];
208 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
209 os <<
"print loop types and indexing maps for: \n";
210 getOperation()->print(os);
212 os <<
"loop types: [";
213 for (utils::IteratorType type : getLoopIteratorTypes()) {
214 os << stringifyEnum(type) <<
" ";
217 os <<
"indexing maps: \n";
230 static LogicalResult fillShardingOption(
Operation *op,
235 if ((shardingOption.
mesh && mesh && shardingOption.
mesh != mesh) ||
238 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts on loop iterator "
242 for (
size_t i = 0; i < shardingOption.
shardingArray.size(); ++i) {
247 if (llvm::is_contained(shardingOption.
shardingArray[i], axis)) {
248 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts because mesh axes "
249 << axis <<
" duplicate");
255 shardingOption.
mesh = mesh;
257 shardingOption.
shardingArray[loopIdx].append(meshAxes.begin(),
264 FailureOr<ShardingOption>
268 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
271 if (failed(shardingOp.verifyShardingInterfaceImpl()))
272 return op->
emitOpError() <<
"invalid sharding interface implementation";
274 shardingOp.getLoopIteratorTypes();
279 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
280 bool anyShardingInResultsOrOperands =
false;
287 AffineMap map = maps[numOperands + shardingIt.index()];
288 anyShardingInResultsOrOperands =
true;
295 auto dim = cast<AffineDimExpr>(expr);
296 unsigned index = dim.getPosition();
297 visitedLoopIndices.insert(index);
298 if (failed(fillShardingOption(op, shardingOption, shardAttr.
getMeshAttr(),
306 if (!partialAxes.empty()) {
307 if (!partialMeshAxes.empty())
308 return op->
emitOpError() <<
"at most one result with partial axes is "
309 "supported at present";
310 partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
313 for (
size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
315 visitedLoopIndices.insert(loopIdx);
326 anyShardingInResultsOrOperands =
true;
327 AffineMap map = maps[shardingIt.index()];
338 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
340 if (failed(loopIndices))
342 <<
"operand's affine expression is restricted to const_i * "
343 "dim_i + const_j + dim_j + ...";
344 if (loopIndices->empty())
346 if (loopIndices->size() == 1) {
347 unsigned loopIdx = *loopIndices->begin();
348 visitedLoopIndices.insert(loopIdx);
349 if (failed(fillShardingOption(op, shardingOption,
356 if (loopIndices->size() > 1) {
357 bool seenLoopIndices =
false;
358 for (
unsigned loopIdx : *loopIndices) {
359 if (visitedLoopIndices.contains(loopIdx)) {
360 seenLoopIndices =
true;
364 if (!seenLoopIndices)
366 <<
"the operand " << shardingIt.index()
367 <<
" has multiple loop indices in a dimension, but none of "
368 "them could be found in the exactly specified annotation "
369 "of op results or operands.";
375 if (!partialMeshAxes.empty()) {
376 bool anyNonEmptyReductionLoop = llvm::any_of(
378 SmallVector<MeshAxis> &subArray = it.value();
379 int64_t idx = it.index();
380 return isReductionLoop(loopTypes[idx]) && !subArray.empty();
382 if (!anyNonEmptyReductionLoop) {
384 for (
size_t idx = 0; idx < loopTypes.size(); ++idx) {
386 std::ignore = fillShardingOption(op, shardingOption,
nullptr,
387 partialMeshAxes, idx);
393 return op->
emitOpError() <<
"no matched reduction loop found for the "
394 "result's partial type";
398 if (!anyShardingInResultsOrOperands)
399 shardingOption.
empty =
true;
400 return shardingOption;
407 auto resultType = cast<RankedTensorType>(result.
getType());
417 auto dim = cast<AffineDimExpr>(expr);
418 unsigned loopIdx = dim.getPosition();
420 splitAxes[it.index()].append(shardingOption.
shardingArray[loopIdx]);
426 size_t reductionLoopKindsIdx = 0;
427 for (
auto it : llvm::zip(loopTypes, shardingOption.
shardingArray)) {
428 utils::IteratorType iType = std::get<0>(it);
430 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
431 ++reductionLoopKindsIdx;
432 if (!partialAxes.empty())
433 assert(partialType == curPartialType &&
434 "Only one reduction type is supported");
435 partialType = curPartialType;
437 partialAxes.append(axis);
444 partialAxes, partialType);
450 Value operandValue = opOperand.
get();
451 auto operandType = cast<RankedTensorType>(operandValue.
getType());
455 int64_t idx = it.index();
457 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
459 if (failed(loopIndices))
462 for (
unsigned loopIdx : *loopIndices) {
465 shardedLoopIndices.push_back(loopIdx);
468 if (shardedLoopIndices.size() > 1)
470 if (shardedLoopIndices.size() == 1) {
471 splitAxes[idx].append(
482 FailureOr<std::vector<MeshSharding>>
485 std::vector<MeshSharding> res;
487 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
489 shardingOp.getLoopIteratorTypes();
491 shardingOp.getReductionLoopIteratorKinds();
496 FailureOr<MeshSharding> shardingAttr =
getSharding(
497 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
498 if (failed(shardingAttr))
500 res.push_back(*shardingAttr);
505 maps[numOperands + result.getResultNumber()],
506 loopTypes, reductionKinds));
524 getSharding(result, shardingOption, map, loopTypes, reductionLoopKinds);
536 FailureOr<MeshSharding> sharding =
538 if (failed(sharding)) {
549 assert(!shardingOption.
empty && shardingOption.
mesh);
551 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
553 shardingOp.getLoopIteratorTypes();
555 shardingOp.getReductionLoopIteratorKinds();
561 if (failed(
addShardOp(b, result, shardingOption,
562 maps[numOperands + result.getResultNumber()],
563 loopTypes, reductionKinds)))
569 if (failed(
addShardOp(b, opOperand, shardingOption,
570 maps[opOperand.getOperandNumber()])))
581 if (isa<RankedTensorType>(value.
getType())) {
588 template <
typename ValueRange,
typename MeshShardingRage>
591 MeshShardingRage &&shardings) {
592 if (std::size(values) != std::size(shardings)) {
596 llvm::zip_equal(std::forward<ValueRange>(values),
597 std::forward<MeshShardingRage>(shardings)),
598 [](
auto valueAndSharding) {
600 std::get<0>(valueAndSharding), std::get<1>(valueAndSharding));
610 assert(spmdizedOperands.size() == operandShardings.size());
616 builder.
clone(op, spmdizationMap);
622 &meshAxesAssignmentForLoopIterators) {
623 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
624 unsigned loopIteratorIdx = affineDimExpr.
getPosition();
625 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
626 assert(llvm::equal(meshAxesAssignmentForTensorAxis,
627 *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
629 meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
630 llvm::to_vector(meshAxesAssignmentForTensorAxis);
640 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
641 std::vector<MeshSharding> operatorAndResultShardings;
642 operatorAndResultShardings.reserve(operandShardings.size() +
643 resultShardings.size());
644 llvm::append_range(operatorAndResultShardings, operandShardings);
645 for (
auto [sharding, affineMap] :
646 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
650 for (
auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
651 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
653 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
654 meshAxisAssignmentForLoopIterators);
657 for (
unsigned i = sharding.getSplitAxes().size();
658 i < affineMap.getNumResults(); ++i) {
660 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
665 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
670 return std::move(*axes);
678 for (
auto [loopIteratorType, meshAxisAssignment] :
679 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
680 if (loopIteratorType == utils::IteratorType::reduction &&
681 !meshAxisAssignment.empty()) {
692 for (
auto [loopIteratorType, meshAxisAssignment] :
693 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
694 if (loopIteratorType == utils::IteratorType::reduction) {
695 llvm::append_range(meshAxes, meshAxisAssignment);
709 for (
auto [oldResult, newResult, sharding] :
713 getMesh(&op, sharding.getMeshAttr(), symbolTable), sharding));
SmallVector< MeshAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingRage &&shardings)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshSharding sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
::mlir::FlatSymbolRefAttr getMeshAttr() const
ArrayRef< MeshAxesAttr > getSplitAxes() const
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={})
ArrayRef< MeshAxis > getPartialAxes() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
FailureOr< std::vector< MeshSharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
bool isReductionLoop(utils::IteratorType iType)
bool isFullReplication(MeshSharding sharding)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
FailureOr< std::pair< bool, MeshSharding > > getMeshSharding(OpResult result)
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
ShardingArray shardingArray