16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
62 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63 if ((
size_t)position >= seenIds.size() || seenIds[position])
65 seenIds[position] =
true;
69 unsigned position = cast<AffineDimExpr>(expr).getPosition();
70 if ((
size_t)position >= seenIds.size() || seenIds[position])
72 seenIds[position] =
true;
80 static FailureOr<llvm::SmallSet<unsigned, 2>>
86 llvm::SmallSet<unsigned, 2> positions;
89 positions.insert((
unsigned)it.index());
98 FailureOr<std::pair<bool, MeshShardingAttr>>
100 Value val = cast<Value>(result);
102 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
105 return !shardOp.getAnnotateForUsers();
108 if (anyShardedForDef) {
113 auto shardOp = llvm::cast<mesh::ShardOp>(*val.
getUsers().begin());
114 return std::make_pair(
false, shardOp.getShard());
118 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
121 return shardOp.getAnnotateForUsers();
123 if (anyShardedForUsers) {
126 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
128 shardOps.push_back(shardOp);
131 for (
size_t i = 1; i < shardOps.size(); ++i) {
134 assert(shardOps[i].getShard() == shardForDef &&
135 "only support all shard ops have the same mesh sharding attr");
137 return std::make_pair(
true, shardForDef);
142 FailureOr<std::pair<bool, MeshShardingAttr>>
146 return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
155 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
160 if (!llvm::isa<RankedTensorType>(type))
163 if (!llvm::isa<RankedTensorType>(type))
168 if (loopTypes.empty())
177 if (numOperands + numResults != maps.size())
181 auto resultType = dyn_cast<RankedTensorType>(result.getType());
184 AffineMap map = maps[numOperands + result.getResultNumber()];
197 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
198 os <<
"print loop types and indexing maps for: \n";
199 getOperation()->print(os);
201 os <<
"loop types: [";
202 for (utils::IteratorType type : getLoopIteratorTypes()) {
203 os << stringifyEnum(type) <<
" ";
206 os <<
"indexing maps: \n";
219 static LogicalResult fillShardingOption(
Operation *op,
224 if ((shardingOption.
mesh && mesh && shardingOption.
mesh != mesh) ||
227 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts on loop iterator "
231 for (
size_t i = 0; i < shardingOption.
shardingArray.size(); ++i) {
236 if (llvm::is_contained(shardingOption.
shardingArray[i], axis)) {
237 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts because mesh axes "
238 << axis <<
" duplicate");
244 shardingOption.
mesh = mesh;
246 shardingOption.
shardingArray[loopIdx].append(meshAxes.begin(),
256 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
259 if (failed(shardingOp.verifyShardingInterfaceImpl()))
260 return op->
emitOpError() <<
"invalid sharding interface implementation";
262 shardingOp.getLoopIteratorTypes();
267 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
268 bool anyShardingInResultsOrOperands =
false;
275 AffineMap map = maps[numOperands + shardingIt.index()];
276 anyShardingInResultsOrOperands =
true;
280 for (
auto it : llvm::zip(map.
getResults(), shardAttr.getSplitAxes())) {
283 auto dim = cast<AffineDimExpr>(expr);
284 unsigned index = dim.getPosition();
285 visitedLoopIndices.insert(index);
286 if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
294 if (!partialAxes.empty()) {
295 if (!partialMeshAxes.empty())
296 return op->
emitOpError() <<
"at most one result with partial axes is "
297 "supported at present";
298 partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
301 for (
size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
303 visitedLoopIndices.insert(loopIdx);
314 anyShardingInResultsOrOperands =
true;
315 AffineMap map = maps[shardingIt.index()];
323 for (
auto it : llvm::zip(map.
getResults(), shardAttr.getSplitAxes())) {
326 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
328 if (failed(loopIndices))
330 <<
"operand's affine expression is restricted to const_i * "
331 "dim_i + const_j + dim_j + ...";
332 if (loopIndices->empty())
334 if (loopIndices->size() == 1) {
335 unsigned loopIdx = *loopIndices->begin();
336 visitedLoopIndices.insert(loopIdx);
337 if (failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
344 if (loopIndices->size() > 1) {
345 bool seenLoopIndices =
false;
346 for (
unsigned loopIdx : *loopIndices) {
347 if (visitedLoopIndices.contains(loopIdx)) {
348 seenLoopIndices =
true;
352 if (!seenLoopIndices)
354 <<
"the operand " << shardingIt.index()
355 <<
" has multiple loop indices in a dimension, but none of "
356 "them could be found in the exactly specified annotation "
357 "of op results or operands.";
363 if (!partialMeshAxes.empty()) {
364 bool anyNonEmptyReductionLoop = llvm::any_of(
366 SmallVector<MeshAxis> &subArray = it.value();
367 int64_t idx = it.index();
368 return isReductionLoop(loopTypes[idx]) && !subArray.empty();
370 if (!anyNonEmptyReductionLoop) {
372 for (
size_t idx = 0; idx < loopTypes.size(); ++idx) {
374 std::ignore = fillShardingOption(op, shardingOption,
nullptr,
375 partialMeshAxes, idx);
381 return op->
emitOpError() <<
"no matched reduction loop found for the "
382 "result's partial type";
386 if (!anyShardingInResultsOrOperands)
387 shardingOption.
empty =
true;
388 return shardingOption;
396 auto resultType = cast<RankedTensorType>(result.
getType());
405 auto dim = cast<AffineDimExpr>(expr);
406 unsigned loopIdx = dim.getPosition();
408 splitAxes[it.index()].append(shardingOption.
shardingArray[loopIdx]);
414 size_t reductionLoopKindsIdx = 0;
415 for (
auto it : llvm::zip(loopTypes, shardingOption.
shardingArray)) {
416 utils::IteratorType iType = std::get<0>(it);
418 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
419 ++reductionLoopKindsIdx;
420 if (!partialAxes.empty())
421 assert(partialType == curPartialType &&
422 "Only one reduction type is supported");
423 partialType = curPartialType;
425 partialAxes.append(axis);
431 splitAxes, partialAxes, partialType);
434 static FailureOr<MeshShardingAttr>
437 Value operandValue = opOperand.
get();
438 auto operandType = cast<RankedTensorType>(operandValue.
getType());
442 int64_t idx = it.index();
444 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
446 if (failed(loopIndices))
449 for (
unsigned loopIdx : *loopIndices) {
452 shardedLoopIndices.push_back(loopIdx);
455 if (shardedLoopIndices.size() > 1)
457 if (shardedLoopIndices.size() == 1) {
458 splitAxes[idx].append(
465 shardingOption.
mesh, splitAxes);
468 FailureOr<SmallVector<MeshShardingAttr>>
473 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
475 shardingOp.getLoopIteratorTypes();
477 shardingOp.getReductionLoopIteratorKinds();
483 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
484 if (failed(shardingAttr))
486 res.push_back(*shardingAttr);
491 result, shardingOption, maps[numOperands + result.getResultNumber()],
492 loopTypes, reductionKinds));
510 result, shardingOption, map, loopTypes, reductionLoopKinds);
522 FailureOr<MeshShardingAttr> shardAttr =
524 if (failed(shardAttr)) {
535 assert(!shardingOption.
empty && shardingOption.
mesh);
537 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
539 shardingOp.getLoopIteratorTypes();
541 shardingOp.getReductionLoopIteratorKinds();
547 if (failed(
addShardOp(b, result, shardingOption,
548 maps[numOperands + result.getResultNumber()],
549 loopTypes, reductionKinds)))
555 if (failed(
addShardOp(b, opOperand, shardingOption,
556 maps[opOperand.getOperandNumber()])))
567 if (isa<RankedTensorType>(value.
getType())) {
574 template <
typename ValueRange,
typename MeshShardingAttrRage>
576 ValueRange &&values, MeshShardingAttrRage &&shardings) {
577 if (std::size(values) != std::size(shardings)) {
580 return llvm::all_of(llvm::zip_equal(
581 std::forward<ValueRange>(values),
582 std::forward<MeshShardingAttrRage>(shardings)),
583 [](
auto valueAndSharding) {
585 std::get<0>(valueAndSharding),
586 std::get<1>(valueAndSharding));
596 assert(spmdizedOperands.size() == operandShardings.size());
602 builder.
clone(op, spmdizationMap);
608 &meshAxesAssignmentForLoopIterators) {
609 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
610 unsigned loopIteratorIdx = affineDimExpr.
getPosition();
611 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
612 assert(llvm::equal(meshAxesAssignmentForTensorAxis,
613 *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
615 meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
616 llvm::to_vector(meshAxesAssignmentForTensorAxis);
626 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
628 operatorAndResultShardings.reserve(operandShardings.size() +
629 resultShardings.size());
630 llvm::append_range(operatorAndResultShardings, operandShardings);
631 for (
auto [sharding, affineMap] :
632 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
636 for (
auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
637 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
639 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
640 meshAxisAssignmentForLoopIterators);
643 for (
unsigned i = sharding.getSplitAxes().size();
644 i < affineMap.getNumResults(); ++i) {
646 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
651 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
656 return std::move(*axes);
664 for (
auto [loopIteratorType, meshAxisAssignment] :
665 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
666 if (loopIteratorType == utils::IteratorType::reduction &&
667 !meshAxisAssignment.empty()) {
678 for (
auto [loopIteratorType, meshAxisAssignment] :
679 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
680 if (loopIteratorType == utils::IteratorType::reduction) {
681 llvm::append_range(meshAxes, meshAxisAssignment);
695 for (
auto [oldResult, newResult, sharding] :
697 newResult.setType(
shardType(newResult.getType(),
698 getMesh(&op, sharding.getMesh(), symbolTable),
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshShardingAttr sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingAttrRage &&shardings)
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
MeshShardingAttr getShardingAttribute(OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
FailureOr< SmallVector< MeshShardingAttr > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(MeshShardingAttr attr)
bool isReductionLoop(utils::IteratorType iType)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
ShardingArray shardingArray