9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
21 class SymbolTableCollection;
29 ArrayRef<MeshShardingAttr> operandShardings,
30 ArrayRef<MeshShardingAttr> resultShardings,
31 ArrayRef<utils::IteratorType> loopIteratorTypes,
32 ArrayRef<AffineMap> indexingMaps);
35 ArrayRef<utils::IteratorType> loopIteratorTypes,
36 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
40 ArrayRef<utils::IteratorType> loopIteratorTypes,
41 ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
46 Operation &op, ArrayRef<Value> spmdizedOperands,
47 ArrayRef<MeshShardingAttr> operandShardings,
48 ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
49 SymbolTableCollection &symbolTable, OpBuilder &builder);
53 template <
typename Op>
55 :
public ShardingInterface::ExternalModel<
56 IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
61 populateIteratorTypes(t, iterTypes);
64 populateIteratorTypes(t, iterTypes);
81 resultShardings, spmdizationMap,
82 symbolTable, builder);
88 populateIteratorTypes(
Type t,
90 RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
91 if (!rankedTensorType) {
95 iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
96 for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
97 iterTypes.push_back(utils::IteratorType::parallel);
103 template <
typename ElemwiseOp>
105 :
public ShardingInterface::ExternalModel<
106 ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
108 Value val = op->getOperand(0);
109 auto type = dyn_cast<RankedTensorType>(val.
getType());
113 utils::IteratorType::parallel);
119 Value val = op->getOperand(0);
120 auto type = dyn_cast<RankedTensorType>(val.
getType());
123 int64_t rank = type.getRank();
124 int64_t num = op->getNumOperands() + op->getNumResults();
137 resultShardings, spmdizationMap,
138 symbolTable, builder);
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
This is a utility class for mapping one set of IR entities to another.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
This class represents a collection of SymbolTables.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const