16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29 #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc"
40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
62 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
63 if ((
size_t)position >= seenIds.size() || seenIds[position])
65 seenIds[position] =
true;
69 unsigned position = cast<AffineDimExpr>(expr).getPosition();
70 if ((
size_t)position >= seenIds.size() || seenIds[position])
72 seenIds[position] =
true;
86 llvm::SmallSet<unsigned, 2> positions;
89 positions.insert((
unsigned)it.index());
100 Value val = cast<Value>(result);
102 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
105 return !shardOp.getAnnotateForUsers();
108 if (anyShardedForDef) {
113 auto shardOp = llvm::cast<mesh::ShardOp>(*val.
getUsers().begin());
114 return std::make_pair(
false, shardOp.getShard());
118 auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user);
121 return shardOp.getAnnotateForUsers();
123 if (anyShardedForUsers) {
126 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
128 shardOps.push_back(shardOp);
131 for (
size_t i = 1; i < shardOps.size(); ++i) {
134 assert(shardOps[i].getShard() == shardForDef &&
135 "only support all shard ops have the same mesh sharding attr");
137 return std::make_pair(
true, shardForDef);
146 return std::make_pair(shardOp.getAnnotateForUsers(), shardOp.getShard());
155 LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() {
160 if (!llvm::isa<RankedTensorType>(type))
163 if (!llvm::isa<RankedTensorType>(type))
168 if (loopTypes.size() == 0)
173 if (maps.size() == 0)
177 if (numOperands + numResults != maps.size())
181 auto resultType = dyn_cast<RankedTensorType>(result.getType());
184 AffineMap map = maps[numOperands + result.getResultNumber()];
197 void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
198 os <<
"print loop types and indexing maps for: \n";
199 getOperation()->print(os);
201 os <<
"loop types: [";
202 for (utils::IteratorType type : getLoopIteratorTypes()) {
203 os << stringifyEnum(type) <<
" ";
206 os <<
"indexing maps: \n";
224 if ((shardingOption.
mesh && mesh && shardingOption.
mesh != mesh) ||
227 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts on loop iterator "
231 for (
size_t i = 0; i < shardingOption.
shardingArray.size(); ++i) {
236 if (llvm::is_contained(shardingOption.
shardingArray[i], axis)) {
237 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts because mesh axes "
238 << axis <<
" duplicate");
244 shardingOption.
mesh = mesh;
246 shardingOption.
shardingArray[loopIdx].append(meshAxes.begin(),
256 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
259 if (
failed(shardingOp.verifyShardingInterfaceImpl()))
260 return op->
emitOpError() <<
"invalid sharding interface implementation";
262 shardingOp.getLoopIteratorTypes();
267 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
268 bool anyShardingInResultsOrOperands =
false;
275 AffineMap map = maps[numOperands + shardingIt.index()];
276 anyShardingInResultsOrOperands =
true;
280 for (
auto it : llvm::zip(map.
getResults(), shardAttr.getSplitAxes())) {
283 auto dim = cast<AffineDimExpr>(expr);
284 unsigned index = dim.getPosition();
285 visitedLoopIndices.insert(index);
286 if (
failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
294 if (!partialAxes.empty()) {
295 if (!partialMeshAxes.empty())
296 return op->
emitOpError() <<
"at most one result with partial axes is "
297 "supported at present";
298 partialMeshAxes.append(partialAxes.begin(), partialAxes.end());
301 for (
size_t loopIdx = 0; loopIdx < loopTypes.size(); ++loopIdx) {
303 visitedLoopIndices.insert(loopIdx);
314 anyShardingInResultsOrOperands =
true;
315 AffineMap map = maps[shardingIt.index()];
323 for (
auto it : llvm::zip(map.
getResults(), shardAttr.getSplitAxes())) {
330 <<
"operand's affine expression is restricted to const_i * "
331 "dim_i + const_j + dim_j + ...";
332 if (loopIndices->empty())
334 if (loopIndices->size() == 1) {
335 unsigned loopIdx = *loopIndices->begin();
336 visitedLoopIndices.insert(loopIdx);
337 if (
failed(fillShardingOption(op, shardingOption, shardAttr.getMesh(),
344 if (loopIndices->size() > 1) {
345 bool seenLoopIndices =
false;
346 for (
unsigned loopIdx : *loopIndices) {
347 if (visitedLoopIndices.contains(loopIdx)) {
348 seenLoopIndices =
true;
352 if (!seenLoopIndices)
354 <<
"the operand " << shardingIt.index()
355 <<
" has multiple loop indices in a dimension, but none of "
356 "them could be found in the exactly specified annotation "
357 "of op results or operands.";
363 if (!partialMeshAxes.empty()) {
364 bool anyNonEmptyReductionLoop = llvm::any_of(
366 SmallVector<MeshAxis> &subArray = it.value();
367 int64_t idx = it.index();
368 return isReductionLoop(loopTypes[idx]) && !subArray.empty();
370 if (!anyNonEmptyReductionLoop) {
372 for (
size_t idx = 0; idx < loopTypes.size(); ++idx) {
374 std::ignore = fillShardingOption(op, shardingOption,
nullptr,
375 partialMeshAxes, idx);
381 return op->
emitOpError() <<
"no matched reduction loop found for the "
382 "result's partial type";
386 if (!anyShardingInResultsOrOperands)
387 shardingOption.
empty =
true;
388 return shardingOption;
404 if (
succeeded(maybeSharding) && !maybeSharding->first)
407 auto resultType = cast<RankedTensorType>(result.
getType());
416 auto dim = cast<AffineDimExpr>(expr);
417 unsigned loopIdx = dim.getPosition();
419 splitAxes[it.index()].append(shardingOption.
shardingArray[loopIdx]);
425 size_t reductionLoopKindsIdx = 0;
426 for (
auto it : llvm::zip(loopTypes, shardingOption.
shardingArray)) {
427 utils::IteratorType iType = std::get<0>(it);
429 ReductionKind curPartialType = reductionLoopKinds[reductionLoopKindsIdx];
430 ++reductionLoopKindsIdx;
431 if (!partialAxes.empty())
432 assert(partialType == curPartialType &&
433 "Only one reduction type is supported");
434 partialType = curPartialType;
436 partialAxes.append(axis);
442 b.
getContext(), shardingOption.
mesh, splitAxes, partialAxes, partialType);
445 auto shardOp = b.
create<ShardOp>(result.
getLoc(), resultType, result,
457 if (
succeeded(maybeShardingAttr) && maybeShardingAttr->first)
460 auto operandType = cast<RankedTensorType>(operand.
getType());
464 int64_t idx = it.index();
471 for (
unsigned loopIdx : *loopIndices) {
474 shardedLoopIndices.push_back(loopIdx);
477 if (shardedLoopIndices.size() > 1)
479 if (shardedLoopIndices.size() == 1) {
480 splitAxes[idx].append(
490 auto shardOp = b.
create<ShardOp>(operand.
getLoc(), operandType, operand,
492 opOperand.
set(shardOp);
499 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
501 shardingOp.getLoopIteratorTypes();
503 shardingOp.getReductionLoopIteratorKinds();
510 maps[numOperands + result.getResultNumber()],
511 loopTypes, reductionKinds)))
518 maps[opOperand.getOperandNumber()])))
529 if (isa<RankedTensorType>(value.
getType())) {
536 template <
typename ValueRange,
typename MeshShardingAttrRage>
538 ValueRange &&values, MeshShardingAttrRage &&shardings) {
539 if (std::size(values) != std::size(shardings)) {
542 return llvm::all_of(llvm::zip_equal(
543 std::forward<ValueRange>(values),
544 std::forward<MeshShardingAttrRage>(shardings)),
545 [](
auto valueAndSharding) {
547 std::get<0>(valueAndSharding),
548 std::get<1>(valueAndSharding));
558 assert(spmdizedOperands.size() == operandShardings.size());
564 builder.
clone(op, spmdizationMap);
570 &meshAxesAssignmentForLoopIterators) {
571 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
572 unsigned loopIteratorIdx = affineDimExpr.
getPosition();
573 if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
574 assert(llvm::equal(meshAxesAssignmentForTensorAxis,
575 *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
577 meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
578 llvm::to_vector(meshAxesAssignmentForTensorAxis);
588 meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
590 operatorAndResultShardings.reserve(operandShardings.size() +
591 resultShardings.size());
592 llvm::append_range(operatorAndResultShardings, operandShardings);
593 for (
auto [sharding, affineMap] :
594 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
598 for (
auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
599 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
601 meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
602 meshAxisAssignmentForLoopIterators);
605 for (
unsigned i = sharding.getSplitAxes().size();
606 i < affineMap.getNumResults(); ++i) {
608 {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
613 llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
618 return std::move(*axes);
626 for (
auto [loopIteratorType, meshAxisAssignment] :
627 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
628 if (loopIteratorType == utils::IteratorType::reduction &&
629 !meshAxisAssignment.empty()) {
640 for (
auto [loopIteratorType, meshAxisAssignment] :
641 llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
642 if (loopIteratorType == utils::IteratorType::reduction) {
643 llvm::append_range(meshAxes, meshAxisAssignment);
657 for (
auto [oldResult, newResult, sharding] :
659 newResult.setType(
shardType(newResult.getType(),
660 getMesh(&op, sharding.getMesh(), symbolTable),
static void updateMeshAxisAssignmentForLoopIterators(ArrayRef< MeshAxis > meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< MeshAxis >>> &meshAxesAssignmentForLoopIterators)
static bool isValueCompatibleWithFullReplicationSharding(Value value, MeshShardingAttr sharding)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes, ArrayRef< ReductionKind > reductionLoopKinds)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, MeshShardingAttrRage &&shardings)
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
void set(IRValueT newValue)
Set the current value being used by this operand.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfterValue(Value val)
Sets the insertion point to the node after the specified value.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void replaceAllUsesExcept(Value newValue, const SmallPtrSetImpl< Operation * > &exceptions)
Replace all uses of 'this' value with 'newValue', updating anything in the IR that uses 'this' to use...
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Operation * getOwner() const
Return the owner of this operand.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
mesh::ReductionKind ReductionKind
mesh::MeshShardingAttr MeshShardingAttr
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
FailureOr< std::pair< bool, MeshShardingAttr > > getMeshShardingAttr(OpResult result)
void spmdizeFullyReplicatedOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(MeshShardingAttr attr)
bool isReductionLoop(utils::IteratorType iType)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
ShardingArray shardingArray