9#ifndef MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
10#define MLIR_DIALECT_SHARD_INTERFACES_SHARDINGINTERFACEIMPL_H_
29 ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings,
30 ArrayRef<utils::IteratorType> loopIteratorTypes,
31 ArrayRef<AffineMap> indexingMaps);
34 ArrayRef<utils::IteratorType> loopIteratorTypes,
35 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
39 ArrayRef<utils::IteratorType> loopIteratorTypes,
40 ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators);
45 ArrayRef<Value> partitionedOperands,
46 ArrayRef<Sharding> operandShardings,
47 ArrayRef<Sharding> resultShardings,
48 IRMapping &partitionMap,
49 SymbolTableCollection &symbolTable,
56 :
public ShardingInterface::ExternalModel<
57 IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
62 populateIteratorTypes(t, iterTypes);
65 populateIteratorTypes(t, iterTypes);
82 operandShardings, resultShardings,
83 partitionMap, symbolTable, builder);
89 populateIteratorTypes(
Type t,
91 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
92 if (!rankedTensorType) {
96 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
97 for (
int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
98 iterTypes.push_back(utils::IteratorType::parallel);
104template <
typename ElemwiseOp>
106 :
public ShardingInterface::ExternalModel<
107 ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
110 auto type = dyn_cast<RankedTensorType>(val.
getType());
114 utils::IteratorType::parallel);
121 auto type = dyn_cast<RankedTensorType>(val.
getType());
138 operandShardings, resultShardings,
139 partitionMap, symbolTable, builder);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
unsigned getNumOperands()
operand_type_range getOperandTypes()
result_type_range getResultTypes()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
void partitionTriviallyShardableOperation(Operation &op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
SmallVector< GridAxis > getReductionGridAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< GridAxis > > gridAxisAssignmentForLoopIterators)
ShardingArray getGridAxisAssignmentForLoopIterators(ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
SmallVector< SmallVector< GridAxis > > ShardingArray
Include the generated interface declarations.
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const
LogicalResult partition(Operation *op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
LogicalResult partition(Operation *op, ArrayRef< Value > partitionedOperands, ArrayRef< Sharding > operandShardings, ArrayRef< Sharding > resultShardings, IRMapping &partitionMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const