16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "sharding-interface"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
29 #include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc"
40 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
50 auto binOpExpr = cast<AffineBinaryOpExpr>(expr);
63 unsigned position = cast<AffineDimExpr>(dimExpr).getPosition();
64 if ((
size_t)position >= seenIds.size() || seenIds[position])
66 seenIds[position] =
true;
70 unsigned position = cast<AffineDimExpr>(expr).getPosition();
71 if ((
size_t)position >= seenIds.size() || seenIds[position])
73 seenIds[position] =
true;
81 static FailureOr<llvm::SmallSet<unsigned, 2>>
87 llvm::SmallSet<unsigned, 2> positions;
90 positions.insert((
unsigned)it.index());
99 for (
const auto &v : vec) {
110 Value val = cast<Value>(result);
112 auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
115 return !shardOp.getAnnotateForUsers();
118 if (anyShardedForDef) {
123 auto shardOp = llvm::cast<shard::ShardOp>(*val.
getUsers().begin());
124 return std::make_pair(
false,
Sharding(shardOp.getSharding()));
128 auto shardOp = llvm::dyn_cast<shard::ShardOp>(user);
131 return shardOp.getAnnotateForUsers();
133 if (anyShardedForUsers) {
136 ShardOp shardOp = llvm::dyn_cast<ShardOp>(user);
138 shardOps.push_back(shardOp);
140 Sharding shardForDef = shardOps[0].getSharding();
141 for (
size_t i = 1; i < shardOps.size(); ++i) {
145 "only support all shard ops have the same grid sharding attr");
147 return std::make_pair(
true, shardForDef);
155 return std::make_pair(shardOp.getAnnotateForUsers(),
165 LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() {
170 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
173 if (!llvm::isa<RankedTensorType>(type) && !type.isIntOrIndexOrFloat())
182 if (numOperands + numResults != maps.size())
186 auto resultType = dyn_cast<RankedTensorType>(result.getType());
189 AffineMap map = maps[numOperands + result.getResultNumber()];
202 void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) {
203 os <<
"print loop types and indexing maps for: \n";
204 getOperation()->print(os);
206 os <<
"loop types: [";
207 for (utils::IteratorType type : getLoopIteratorTypes()) {
208 os << stringifyEnum(type) <<
" ";
211 os <<
"indexing maps: \n";
224 static LogicalResult fillShardingOption(
Operation *op,
229 if ((shardingOption.
grid && grid && shardingOption.
grid != grid) ||
232 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts on loop iterator "
236 for (
size_t i = 0; i < shardingOption.
shardingArray.size(); ++i) {
241 if (llvm::is_contained(shardingOption.
shardingArray[i], axis)) {
242 LLVM_DEBUG(
DBGS() <<
"sharding option conflicts because grid axes "
243 << axis <<
" duplicate");
249 shardingOption.
grid = grid;
251 shardingOption.
shardingArray[loopIdx].append(gridAxes.begin(),
258 FailureOr<ShardingOption>
262 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
265 if (
failed(shardingOp.verifyShardingInterfaceImpl()))
266 return op->
emitOpError() <<
"invalid sharding interface implementation";
268 shardingOp.getLoopIteratorTypes();
272 llvm::SmallSet<unsigned, 4> visitedLoopIndices;
273 bool anyShardingInResultsOrOperands =
false;
277 Sharding shardAttr = shardingIt.value();
280 AffineMap map = maps[numOperands + shardingIt.index()];
281 anyShardingInResultsOrOperands =
true;
291 auto dim = cast<AffineDimExpr>(expr);
292 unsigned index = dim.getPosition();
293 visitedLoopIndices.insert(index);
294 if (
failed(fillShardingOption(op, shardingOption,
303 Sharding shardAttr = shardingIt.value();
307 anyShardingInResultsOrOperands = !shardAttr.
getSplitAxes().empty();
308 AffineMap map = maps[shardingIt.index()];
318 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
322 <<
"operand's affine expression is restricted to const_i * "
323 "dim_i + const_j + dim_j + ...";
324 if (loopIndices->empty())
326 if (loopIndices->size() == 1) {
327 unsigned loopIdx = *loopIndices->begin();
328 visitedLoopIndices.insert(loopIdx);
329 if (
failed(fillShardingOption(op, shardingOption,
336 if (loopIndices->size() > 1) {
337 bool seenLoopIndices =
false;
338 for (
unsigned loopIdx : *loopIndices) {
339 if (visitedLoopIndices.contains(loopIdx)) {
340 seenLoopIndices =
true;
344 if (!seenLoopIndices)
346 <<
"the operand " << shardingIt.index()
347 <<
" has multiple loop indices in a dimension, but none of "
348 "them could be found in the exactly specified annotation "
349 "of op results or operands.";
356 if (!anyShardingInResultsOrOperands)
357 shardingOption.
empty =
true;
358 return shardingOption;
365 auto resultType = cast<RankedTensorType>(result.
getType());
373 auto dim = cast<AffineDimExpr>(expr);
374 unsigned loopIdx = dim.getPosition();
376 splitAxes[it.index()].append(shardingOption.
shardingArray[loopIdx]);
387 Value operandValue = opOperand.
get();
388 auto operandType = dyn_cast<RankedTensorType>(operandValue.
getType());
395 if (operandType.getRank() == 0) {
401 int64_t idx = it.index();
403 FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices =
408 for (
unsigned loopIdx : *loopIndices) {
411 shardedLoopIndices.push_back(loopIdx);
414 if (shardedLoopIndices.size() > 1)
416 if (shardedLoopIndices.size() == 1) {
417 splitAxes[idx].append(
430 std::vector<Sharding> res;
432 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
434 shardingOp.getLoopIteratorTypes();
440 opOperand, shardingOption, maps[opOperand.getOperandNumber()]);
443 res.push_back(*shardingAttr);
447 res.push_back(::
getSharding(result, shardingOption,
448 maps[numOperands + result.getResultNumber()],
477 FailureOr<Sharding> sharding =
getSharding(opOperand, shardingOption, map);
489 assert(!shardingOption.
empty && shardingOption.
grid);
491 ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op);
493 shardingOp.getLoopIteratorTypes();
500 maps[numOperands + result.getResultNumber()],
508 maps[opOperand.getOperandNumber()])))
519 if (isa<RankedTensorType>(value.
getType())) {
526 template <
typename ValueRange,
typename ShardingRage>
529 ShardingRage &&shardings) {
530 if (std::size(values) != std::size(shardings)) {
533 return llvm::all_of(llvm::zip_equal(std::forward<ValueRange>(values),
534 std::forward<ShardingRage>(shardings)),
535 [](
auto valueAndSharding) {
537 std::get<0>(valueAndSharding),
538 std::get<1>(valueAndSharding));
548 assert(partitionedOperands.size() == operandShardings.size());
554 builder.
clone(op, partitionMap);
560 &gridAxesAssignmentForLoopIterators) {
561 AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
562 unsigned loopIteratorIdx = affineDimExpr.
getPosition();
563 if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) {
564 assert(llvm::equal(gridAxesAssignmentForTensorAxis,
565 *gridAxesAssignmentForLoopIterators[loopIteratorIdx]));
567 gridAxesAssignmentForLoopIterators[loopIteratorIdx] =
568 llvm::to_vector(gridAxesAssignmentForTensorAxis);
577 gridAxisAssignmentForLoopIterators(loopIteratorTypes.size());
578 std::vector<Sharding> operatorAndResultShardings;
579 operatorAndResultShardings.reserve(operandShardings.size() +
580 resultShardings.size());
581 llvm::append_range(operatorAndResultShardings, operandShardings);
582 for (
auto [sharding, affineMap] :
583 llvm::zip_equal(operatorAndResultShardings, indexingMaps)) {
587 for (
auto [gridAxesAssignmentForTensorAxis, indexingExpr] :
588 llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
590 gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
591 gridAxisAssignmentForLoopIterators);
594 for (
unsigned i = sharding.getSplitAxes().size();
595 i < affineMap.getNumResults(); ++i) {
597 {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators);
602 llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res),
607 return std::move(*axes);
615 for (
auto [loopIteratorType, gridAxisAssignment] :
616 llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
617 if (loopIteratorType == utils::IteratorType::reduction &&
618 !gridAxisAssignment.empty()) {
629 for (
auto [loopIteratorType, gridAxisAssignment] :
630 llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) {
631 if (loopIteratorType == utils::IteratorType::reduction) {
632 llvm::append_range(gridAxes, gridAxisAssignment);
646 for (
auto [oldResult, newResult, sharding] :
650 getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding));
static bool isValueCompatibleWithFullReplicationSharding(Value value, const Sharding &sharding)
static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef< utils::IteratorType > loopTypes)
static FailureOr< llvm::SmallSet< unsigned, 2 > > checkOperandAffineExpr(AffineExpr expr, unsigned numDims)
static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, ShardingRage &&shardings)
static LogicalResult checkOperandAffineExprRecursively(AffineExpr expr, SmallVectorImpl< bool > &seenIds)
static void updateGridAxisAssignmentForLoopIterators(ArrayRef< GridAxis > gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr, SmallVector< std::optional< SmallVector< GridAxis >>> &gridAxesAssignmentForLoopIterators)
SmallVector< GridAxesAttr > fromArrayOfVector(MLIRContext *ctxt, const SmallVector< SmallVector< T >> &vec)
A dimensional identifier appearing in an affine expression.
unsigned getPosition() const
Base type for affine expression.
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
IRValueT get() const
Return the current value being used by this operand.
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
unsigned getNumOperands()
operand_type_range getOperandTypes()
MutableArrayRef< OpOperand > getOpOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
user_range getUsers() const
bool hasOneUse() const
Returns true if this value has exactly one use.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
static Sharding get(::mlir::FlatSymbolRefAttr grid_, ArrayRef< GridAxesAttr > split_axes_, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
::mlir::FlatSymbolRefAttr getGridAttr() const
ArrayRef< GridAxesAttr > getSplitAxes() const
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LogicalResult defaultAddShardingAnnotations(Operation *op, OpBuilder &b, const ShardingOption &shardingOption)
FailureOr< std::vector< Sharding > > defaultGetShardingAnnotations(Operation *op, const ShardingOption &shardingOption)
FailureOr< ShardingOption > defaultGetShardingOption(Operation *op, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings)
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
void maybeInsertTargetShardingAnnotation(Sharding sharding, OpResult result, OpBuilder &builder)
void partitionFullyReplicatedOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
shard::GridOp getGridOrNull(Operation *op, FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTableCollection)
bool isFullReplication(Sharding sharding)
void maybeInsertSourceShardingAnnotation(Sharding sharding, OpOperand &operand, OpBuilder &builder)
FailureOr< std::pair< bool, Sharding > > getSharding(OpResult result)
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Type shardType(Type type, GridOp grid, Sharding sharding)
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
Include the generated interface declarations.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ DimId
Dimensional identifier.
@ Constant
Constant integer.
ShardingArray shardingArray