9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
21 class SymbolTableCollection;
29 ArrayRef<MeshSharding> operandShardings,
30 ArrayRef<MeshSharding> resultShardings,
31 ArrayRef<utils::IteratorType> loopIteratorTypes,
32 ArrayRef<AffineMap> indexingMaps);
35 ArrayRef<utils::IteratorType> loopIteratorTypes,
36 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
40 ArrayRef<utils::IteratorType> loopIteratorTypes,
41 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
46 ArrayRef<Value> spmdizedOperands,
47 ArrayRef<MeshSharding> operandShardings,
48 ArrayRef<MeshSharding> resultShardings,
49 IRMapping &spmdizationMap,
50 SymbolTableCollection &symbolTable,
55 template <
typename Op>
57 :
public ShardingInterface::ExternalModel<
58 IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
63 populateIteratorTypes(t, iterTypes);
66 populateIteratorTypes(t, iterTypes);
83 resultShardings, spmdizationMap,
84 symbolTable, builder);
90 populateIteratorTypes(
Type t,
92 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
93 if (!rankedTensorType) {
97 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
98 for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
99 iterTypes.push_back(utils::IteratorType::parallel);
105 template <
typename ElemwiseOp>
107 :
public ShardingInterface::ExternalModel<
108 ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
111 auto type = dyn_cast<RankedTensorType>(val.
getType());
115 utils::IteratorType::parallel);
122 auto type = dyn_cast<RankedTensorType>(val.
getType());
125 int64_t rank = type.getRank();
139 resultShardings, spmdizationMap,
140 symbolTable, builder);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumOperands()
operand_type_range getOperandTypes()
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Include the generated interface declarations.
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshSharding > operandShardings, ArrayRef< MeshSharding > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const