MLIR  19.0.0git
ShardingInterfaceImpl.h
Go to the documentation of this file.
1 //===- ShardingInterfaceImpl.h ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
10 #define MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
11 
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/Value.h"
16 
17 namespace mlir {
18 
19 class Operation;
20 class IRMapping;
21 class SymbolTableCollection;
22 
23 namespace mesh {
24 
25 // Retrieve the mesh axes corresponding to each operation loop iterator based
26 // on the provided shardings for the op's operands and results.
27 // Assumes that the indexingMaps are projected permutations.
29  ArrayRef<MeshShardingAttr> operandShardings,
30  ArrayRef<MeshShardingAttr> resultShardings,
31  ArrayRef<utils::IteratorType> loopIteratorTypes,
32  ArrayRef<AffineMap> indexingMaps);
33 
35  ArrayRef<utils::IteratorType> loopIteratorTypes,
36  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
37 
38 // Get the set of mesh axes that correspond to reduction loop iterators.
39 SmallVector<MeshAxis> getReductionMeshAxes(
40  ArrayRef<utils::IteratorType> loopIteratorTypes,
41  ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
42 
43 // Inserts a clone of the operation that has all ranked tensor
44 // arguments/results sharded.
46  Operation &op, ArrayRef<Value> spmdizedOperands,
47  ArrayRef<MeshShardingAttr> operandShardings,
48  ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
49  SymbolTableCollection &symbolTable, OpBuilder &builder);
50 
51 // All ranked tensor argument and result dimensions have
52 // independent parallel loop iterators.
53 template <typename Op>
55  : public ShardingInterface::ExternalModel<
56  IndependentParallelIteratorDomainShardingInterface<Op>, Op> {
58  getLoopIteratorTypes(Operation *operation) const {
60  for (Type t : operation->getOperandTypes()) {
61  populateIteratorTypes(t, iterTypes);
62  }
63  for (Type t : operation->getResultTypes()) {
64  populateIteratorTypes(t, iterTypes);
65  }
66  return iterTypes;
67  }
68 
70  // TODO: implement.
71  return SmallVector<AffineMap>();
72  }
73 
74  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
75  ArrayRef<MeshShardingAttr> operandShardings,
76  ArrayRef<MeshShardingAttr> resultShardings,
77  IRMapping &spmdizationMap,
78  SymbolTableCollection &symbolTable,
79  OpBuilder &builder) const {
80  spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
81  resultShardings, spmdizationMap,
82  symbolTable, builder);
83  return success();
84  }
85 
86 private:
87  void
88  populateIteratorTypes(Type t,
89  SmallVector<utils::IteratorType> &iterTypes) const {
90  RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(t);
91  if (!rankedTensorType) {
92  return;
93  }
94 
95  iterTypes.reserve(iterTypes.size() + rankedTensorType.getRank());
96  for (int64_t i = 0; i < rankedTensorType.getRank(); ++i) {
97  iterTypes.push_back(utils::IteratorType::parallel);
98  }
99  }
100 };
101 
102 // Sharding of elementwise operations like tensor addition and multiplication.
103 template <typename ElemwiseOp>
105  : public ShardingInterface::ExternalModel<
106  ElementwiseShardingInterface<ElemwiseOp>, ElemwiseOp> {
108  Value val = op->getOperand(0);
109  auto type = dyn_cast<RankedTensorType>(val.getType());
110  if (!type)
111  return {};
112  SmallVector<utils::IteratorType> types(type.getRank(),
113  utils::IteratorType::parallel);
114  return types;
115  }
116 
118  MLIRContext *ctx = op->getContext();
119  Value val = op->getOperand(0);
120  auto type = dyn_cast<RankedTensorType>(val.getType());
121  if (!type)
122  return {};
123  int64_t rank = type.getRank();
124  int64_t num = op->getNumOperands() + op->getNumResults();
125  SmallVector<AffineMap> maps(num,
127  return maps;
128  }
129 
130  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
131  ArrayRef<MeshShardingAttr> operandShardings,
132  ArrayRef<MeshShardingAttr> resultShardings,
133  IRMapping &spmdizationMap,
134  SymbolTableCollection &symbolTable,
135  OpBuilder &builder) const {
136  spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
137  resultShardings, spmdizationMap,
138  symbolTable, builder);
139  return success();
140  }
141 };
142 
143 } // namespace mesh
144 } // namespace mlir
145 
146 #endif // MLIR_DIALECT_MESH_INTERFACES_SHARDINGINTERFACEIMPL_H_
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:321
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool isAtLeastOneReductionIteratorSharded(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
SmallVector< SmallVector< MeshAxis > > ShardingArray
SmallVector< MeshAxis > getReductionMeshAxes(ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< SmallVector< MeshAxis >> meshAxisAssignmentForLoopIterators)
ShardingArray getMeshAxisAssignmentForLoopIterators(ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, ArrayRef< utils::IteratorType > loopIteratorTypes, ArrayRef< AffineMap > indexingMaps)
void spmdizeTriviallyShardableOperation(Operation &op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder)
Include the generated interface declarations.
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *op) const
SmallVector< AffineMap > getIndexingMaps(Operation *op) const
SmallVector< utils::IteratorType > getLoopIteratorTypes(Operation *operation) const
LogicalResult spmdize(Operation *op, ArrayRef< Value > spmdizedOperands, ArrayRef< MeshShardingAttr > operandShardings, ArrayRef< MeshShardingAttr > resultShardings, IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, OpBuilder &builder) const