MLIR  20.0.0git
MeshOps.h
Go to the documentation of this file.
1 //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
11 
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/SymbolTable.h"
20 #include "llvm/Support/MathExtras.h"
21 
22 namespace mlir {
23 namespace mesh {
24 
// Integer type that names one axis of a mesh. Values of this type are used as
// offsets into a mesh-shape range (see collectiveProcessGroupSize).
25 using MeshAxis = int16_t;
29 
30 } // namespace mesh
31 } // namespace mlir
32 
33 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
34 
35 #define GET_ATTRDEF_CLASSES
36 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
37 
38 namespace mlir {
39 namespace mesh {
40 
41 class MeshSharding {
42 private:
44  SmallVector<MeshAxesAttr> split_axes;
45  SmallVector<MeshAxis> partial_axes;
46  ReductionKind partial_type;
47  SmallVector<int64_t> static_halo_sizes;
48  SmallVector<int64_t> static_sharded_dims_sizes;
49  SmallVector<Value> dynamic_halo_sizes;
50  SmallVector<Value> dynamic_sharded_dims_sizes;
51 
52 public:
53  MeshSharding() = default;
54  MeshSharding(Value rhs);
56  ArrayRef<MeshAxesAttr> split_axes_,
57  ArrayRef<MeshAxis> partial_axes_ = {},
58  ReductionKind partial_type_ = ReductionKind::Sum,
59  ArrayRef<int64_t> static_halo_sizes_ = {},
60  ArrayRef<int64_t> static_sharded_dims_sizes_ = {},
61  ArrayRef<Value> dynamic_halo_sizes_ = {},
62  ArrayRef<Value> dynamic_sharded_dims_sizes_ = {});
63  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
64  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
65  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
66  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
67  ReductionKind getPartialType() const { return partial_type; }
68  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
70  return static_sharded_dims_sizes;
71  }
72  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
74  return dynamic_sharded_dims_sizes;
75  }
76  operator bool() const { return (!mesh) == false; }
77  bool operator==(Value rhs) const;
78  bool operator!=(Value rhs) const;
79  bool operator==(const MeshSharding &rhs) const;
80  bool operator!=(const MeshSharding &rhs) const;
81  bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
82  bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
83 };
84 
85 } // namespace mesh
86 } // namespace mlir
87 
88 #define GET_TYPEDEF_CLASSES
89 #include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
90 
91 #define GET_OP_CLASSES
92 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
93 
94 namespace mlir {
95 namespace mesh {
96 
97 inline bool isReductionLoop(utils::IteratorType iType) {
98  return iType == utils::IteratorType::reduction;
99 }
100 
101 template <typename T>
103  while (!array.empty() && array.back().empty())
104  array.pop_back();
105 }
106 
107 // Is the same tensor replicated on all processes.
108 inline bool isFullReplication(MeshSharding sharding) {
109  return sharding.getPartialAxes().empty() &&
110  llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
111  return axes.asArrayRef().empty();
112  });
113 }
114 
115 inline mesh::MeshOp
117  SymbolTableCollection &symbolTableCollection) {
118  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
119  op, meshSymbol);
120 }
121 
123  SymbolTableCollection &symbolTableCollection) {
124  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
125  assert(meshOp);
126  return meshOp;
127 }
128 
129 // Get the corresponding mesh op using the standard attribute nomenclature.
130 template <typename Op>
131 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
132  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
133 }
134 
135 template <>
137 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
138  return getMesh(
139  op.getOperation(),
140  cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
141  symbolTableCollection);
142 }
143 
144 // Get the number of processes that participate in each group
145 // induced by `meshAxes`.
146 template <typename MeshAxesRange, typename MeshShapeRange>
147 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
148  MeshShapeRange &&meshShape) {
149  int64_t res = 1;
150 
151  for (MeshAxis axis : meshAxes) {
152  auto axisSize = *(std::begin(meshShape) + axis);
153  if (ShapedType::isDynamic(axisSize)) {
154  return ShapedType::kDynamic;
155  }
156  res *= axisSize;
157  }
158 
159  return res;
160 }
161 
162 template <typename MeshAxesRange>
163 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
164  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
165  mesh.getShape());
166 }
167 
168 // Get the size of a sharded dimension.
169 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
170  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
171  return ShapedType::kDynamic;
172 
173  assert(dimSize % shardCount == 0);
174  return dimSize / shardCount;
175 }
176 
177 // Get the size of an unsharded dimension.
178 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
179  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
180  return ShapedType::kDynamic;
181 
182  return dimSize * shardCount;
183 }
184 
185 // Return the sharded shape `shape` according to sharding `sharding`.
186 // This is the shape of the tensor on each device in the mesh.
187 // Example:
188 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
189 // result in a shape for each shard of ?x2x?.
190 ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
191  MeshSharding sharding);
192 
193 // If `type` is a ranked tensor type, return its sharded counterpart.
194 //
195 // If `type` is not a ranked tensor type, return `type` itself.
196 // `sharding` must be null in that case.
197 Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
198 
199 // Insert a shard op if there is not one that already has the same sharding.
200 // May insert resharding if required.
// NOTE(review): header lines 201, 204 and 206 -- the opening line of each of
// the three declarations below -- are absent from this listing. The symbol
// index further down the page gives two of the signatures:
//   void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
//                                            OpOperand &operand,
//                                            OpBuilder &builder);
//   void maybeInsertSourceShardingAnnotation(MeshSharding sharding,
//                                            OpOperand &operand,
//                                            OpBuilder &builder);
// Which signature belongs to which fragment, and the opening of the fragment
// that ends at line 205, cannot be determined from here -- confirm against
// the upstream header.
202  OpOperand &operand,
203  OpBuilder &builder);
205  OpBuilder &builder);
207  OpOperand &operand,
208  OpBuilder &builder);
209 
210 } // namespace mesh
211 } // namespace mlir
212 
213 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class helps build Operations.
Definition: Builders.h:212
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:553
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:578
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:63
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:65
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_sizes_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_sizes_={})
Definition: MeshOps.cpp:637
bool operator!=(Value rhs) const
Definition: MeshOps.cpp:615
ReductionKind getPartialType() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicShardedDimsSizes() const
Definition: MeshOps.h:73
ArrayRef< int64_t > getStaticShardedDimsSizes() const
Definition: MeshOps.h:69
bool operator==(Value rhs) const
Definition: MeshOps.cpp:611
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:66
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:72
::llvm::StringRef getMesh() const
Definition: MeshOps.h:64
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:68
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:147
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:116
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:313
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:178
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:263
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:253
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:122
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:169
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:97
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:108
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:137
int16_t MeshAxis
Definition: MeshOps.h:25
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:271
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:102
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr