9 #ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
10 #define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
21 class SymbolTableCollection;
29 ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
30 ArrayRef<utils::IteratorType> loopIteratorTypes,
31 ArrayRef<AffineMap> indexingMaps);
34 ArrayRef<utils::IteratorType> loopIteratorTypes,
35 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
39 ArrayRef<utils::IteratorType> loopIteratorTypes,
40 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
45 ArrayRef<Value> partitionedOperands,
46 ArrayRef<Sharding> operandShardings,
47 ArrayRef<Sharding> resultShardings,
48 IRMapping &partitionMap,
49 SymbolTableCollection &symbolTable,
54 template <
typename Op>
56 :
public ShardingInterface::ExternalModel<
57 IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
62 populateIteratorTypes(t, iterTypes);
65 populateIteratorTypes(t, iterTypes);
82 operandShardings, resultShardings,
83 partitionMap, symbolTable, builder);
89 populateIteratorTypes(
Type t,
91 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
92 if (!rankedTensorType) {
96 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
97 for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
98 iterTypes.push_back(utils::IteratorType::parallel);
104 template <
typename ElemwiseOp>
106 :
public ShardingInterface::ExternalModel<
107 ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
110 auto type = dyn_cast<RankedTensorType>(val.
getType());
114 utils::IteratorType::parallel);
121 auto type = dyn_cast<RankedTensorType>(val.
getType());
124 int64_t rank = type.getRank();
138 operandShardings, resultShardings,
139 partitionMap, symbolTable, builder);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
operand_type_range getOperandTypes()
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis >> gridAxisAssignmentForLoopIterators)
SmallVector< SmallVector< GridAxis > > ShardingArray
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
Include the generated interface declarations.
LogicalResult partition(Operation *op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
LogicalResult partition(Operation *op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const