MLIR  20.0.0git
MeshOps.h
Go to the documentation of this file.
1 //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
11 
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/SymbolTable.h"
21 #include "llvm/Support/MathExtras.h"
22 
23 namespace mlir {
24 namespace mesh {
25 
26 using MeshAxis = int16_t;
30 
31 } // namespace mesh
32 } // namespace mlir
33 
34 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
35 
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
38 
39 namespace mlir {
40 namespace mesh {
41 
42 class MeshSharding {
43 private:
45  SmallVector<MeshAxesAttr> split_axes;
46  SmallVector<MeshAxis> partial_axes;
47  ReductionKind partial_type;
48  SmallVector<int64_t> static_halo_sizes;
49  SmallVector<int64_t> static_sharded_dims_offsets;
50  SmallVector<Value> dynamic_halo_sizes;
51  SmallVector<Value> dynamic_sharded_dims_offsets;
52 
53 public:
54  MeshSharding() = default;
55  MeshSharding(Value rhs);
57  ArrayRef<MeshAxesAttr> split_axes_,
58  ArrayRef<MeshAxis> partial_axes_ = {},
59  ReductionKind partial_type_ = ReductionKind::Sum,
60  ArrayRef<int64_t> static_halo_sizes_ = {},
61  ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
62  ArrayRef<Value> dynamic_halo_sizes_ = {},
63  ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
64  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
65  ::llvm::StringRef getMesh() const { return mesh.getValue(); }
66  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
67  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
68  ReductionKind getPartialType() const { return partial_type; }
69  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
71  return static_sharded_dims_offsets;
72  }
73  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
75  return dynamic_sharded_dims_offsets;
76  }
77  operator bool() const { return (!mesh) == false; }
78  bool operator==(Value rhs) const;
79  bool operator!=(Value rhs) const;
80  bool operator==(const MeshSharding &rhs) const;
81  bool operator!=(const MeshSharding &rhs) const;
82  bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
83  bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
84  bool equalHaloSizes(const MeshSharding &rhs) const;
85  bool equalShardSizes(const MeshSharding &rhs) const;
86 };
87 
88 } // namespace mesh
89 } // namespace mlir
90 
91 #define GET_TYPEDEF_CLASSES
92 #include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
93 
94 #define GET_OP_CLASSES
95 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
96 
97 namespace mlir {
98 namespace mesh {
99 
100 inline bool isReductionLoop(utils::IteratorType iType) {
101  return iType == utils::IteratorType::reduction;
102 }
103 
104 // Remove empty subarrays of `array` until a minimum length of one is reached.
105 template <typename T>
107  while (array.size() > 1 && array.back().empty())
108  array.pop_back();
109 }
110 
111 // Is the same tensor replicated on all processes.
112 inline bool isFullReplication(MeshSharding sharding) {
113  return sharding.getPartialAxes().empty() &&
114  llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
115  return axes.asArrayRef().empty();
116  });
117 }
118 
119 inline mesh::MeshOp
121  SymbolTableCollection &symbolTableCollection) {
122  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
123  op, meshSymbol);
124 }
125 
127  SymbolTableCollection &symbolTableCollection) {
128  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
129  assert(meshOp);
130  return meshOp;
131 }
132 
133 // Get the corresponding mesh op using the standard attribute nomenclature.
134 template <typename Op>
135 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
136  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
137 }
138 
139 template <>
141 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
142  return getMesh(
143  op.getOperation(),
144  cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
145  symbolTableCollection);
146 }
147 
148 // Get the number of processes that participate in each group
149 // induced by `meshAxes`.
150 template <typename MeshAxesRange, typename MeshShapeRange>
151 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
152  MeshShapeRange &&meshShape) {
153  int64_t res = 1;
154 
155  for (MeshAxis axis : meshAxes) {
156  auto axisSize = *(std::begin(meshShape) + axis);
157  if (ShapedType::isDynamic(axisSize)) {
158  return ShapedType::kDynamic;
159  }
160  res *= axisSize;
161  }
162 
163  return res;
164 }
165 
166 template <typename MeshAxesRange>
167 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
168  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
169  mesh.getShape());
170 }
171 
172 // Get the size of a sharded dimension.
173 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
174  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
175  return ShapedType::kDynamic;
176 
177  assert(dimSize % shardCount == 0);
178  return dimSize / shardCount;
179 }
180 
181 // Get the size of an unsharded dimension.
182 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
183  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
184  return ShapedType::kDynamic;
185 
186  return dimSize * shardCount;
187 }
188 
189 // Return the sharded shape `shape` according to sharding `sharding`.
190 // The shape for the tensor on each device in the mesh.
191 // Example:
192 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
193 // result in a shape for each shard of ?x2x?.
194 ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
195  MeshSharding sharding);
196 
197 // If ranked tensor type return its sharded counterpart.
198 //
199 // If not ranked tensor type return `type`.
200 // `sharding` in that case must be null.
201 Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
202 
203 // Insert shard op if there is not one that already has the same sharding.
204 // May insert resharding if required.
206  OpOperand &operand,
207  OpBuilder &builder);
209  OpBuilder &builder);
211  OpOperand &operand,
212  OpBuilder &builder);
213 
214 } // namespace mesh
215 } // namespace mlir
216 
217 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class helps build Operations.
Definition: Builders.h:215
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:108
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:623
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:653
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:678
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
bool operator!=(Value rhs) const
Definition: MeshOps.cpp:700
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition: MeshOps.h:74
bool operator==(Value rhs) const
Definition: MeshOps.cpp:696
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
::llvm::StringRef getMesh() const
Definition: MeshOps.h:65
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:722
bool equalShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:657
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:151
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:120
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:314
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:182
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:264
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:254
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:126
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:173
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:100
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:112
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:141
int16_t MeshAxis
Definition: MeshOps.h:26
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:272
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:106
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr