MLIR  21.0.0git
MeshOps.h
Go to the documentation of this file.
1 //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
11 
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/SymbolTable.h"
21 #include "llvm/Support/MathExtras.h"
22 
23 namespace mlir {
24 namespace mesh {
25 
// Integer type used to identify an axis of a mesh (16-bit signed).
26 using MeshAxis = int16_t;
30 
31 } // namespace mesh
32 } // namespace mlir
33 
34 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
35 
36 #define GET_ATTRDEF_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
38 
39 namespace mlir {
40 namespace mesh {
41 
42 class MeshSharding {
43 private:
45  SmallVector<MeshAxesAttr> split_axes;
46  SmallVector<MeshAxis> partial_axes;
47  ReductionKind partial_type;
48  SmallVector<int64_t> static_halo_sizes;
49  SmallVector<int64_t> static_sharded_dims_offsets;
50  SmallVector<Value> dynamic_halo_sizes;
51  SmallVector<Value> dynamic_sharded_dims_offsets;
52 
53 public:
54  MeshSharding(::mlir::FlatSymbolRefAttr mesh_ = nullptr);
55  MeshSharding(Value rhs);
57  ArrayRef<MeshAxesAttr> split_axes_,
58  ArrayRef<MeshAxis> partial_axes_ = {},
59  ReductionKind partial_type_ = ReductionKind::Sum,
60  ArrayRef<int64_t> static_halo_sizes_ = {},
61  ArrayRef<int64_t> static_sharded_dims_offsets_ = {},
62  ArrayRef<Value> dynamic_halo_sizes_ = {},
63  ArrayRef<Value> dynamic_sharded_dims_offsets_ = {});
64  ::mlir::FlatSymbolRefAttr getMeshAttr() const { return mesh; }
65  ::llvm::StringRef getMesh() const { return mesh ? mesh.getValue() : ""; }
66  ArrayRef<MeshAxesAttr> getSplitAxes() const { return split_axes; }
67  ArrayRef<MeshAxis> getPartialAxes() const { return partial_axes; }
68  ReductionKind getPartialType() const { return partial_type; }
69  ArrayRef<int64_t> getStaticHaloSizes() const { return static_halo_sizes; }
71  return static_sharded_dims_offsets;
72  }
73  ArrayRef<Value> getDynamicHaloSizes() const { return dynamic_halo_sizes; }
75  return dynamic_sharded_dims_offsets;
76  }
77  operator bool() const { return (!mesh) == false; }
78  bool operator==(Value rhs) const;
79  bool operator!=(Value rhs) const;
80  bool operator==(const MeshSharding &rhs) const;
81  bool operator!=(const MeshSharding &rhs) const;
82  bool equalSplitAndPartialAxes(const MeshSharding &rhs) const;
83  bool equalHaloAndShardSizes(const MeshSharding &rhs) const;
84  bool equalHaloSizes(const MeshSharding &rhs) const;
85  bool equalShardSizes(const MeshSharding &rhs) const;
86 };
87 
88 } // namespace mesh
89 } // namespace mlir
90 
91 #define GET_TYPEDEF_CLASSES
92 #include "mlir/Dialect/Mesh/IR/MeshTypes.h.inc"
93 
94 #define GET_OP_CLASSES
95 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
96 
97 namespace mlir {
98 namespace mesh {
99 
100 inline bool isReductionLoop(utils::IteratorType iType) {
101  return iType == utils::IteratorType::reduction;
102 }
103 
104 // Remove empty subarrays of `array` until a minimum lengh of one is reached.
105 template <typename T>
107  while (array.size() > 1 && array.back().empty())
108  array.pop_back();
109 }
110 
111 // Is the same tensor replicated on all processes.
112 inline bool isFullReplication(MeshSharding sharding) {
113  return sharding.getPartialAxes().empty() &&
114  llvm::all_of(sharding.getSplitAxes(), [](MeshAxesAttr axes) {
115  return axes.asArrayRef().empty();
116  });
117 }
118 
119 inline mesh::MeshOp
121  SymbolTableCollection &symbolTableCollection) {
122  if (!meshSymbol)
123  return nullptr;
124  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
125  op, meshSymbol);
126 }
127 
129  SymbolTableCollection &symbolTableCollection) {
130  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
131  assert(meshOp);
132  return meshOp;
133 }
134 
135 // Get the corresponding mesh op using the standard attribute nomenclature.
136 template <typename Op>
137 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
138  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
139 }
140 
141 template <>
143 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
144  return getMesh(
145  op.getOperation(),
146  cast<ShardingOp>(op.getSharding().getDefiningOp()).getMeshAttr(),
147  symbolTableCollection);
148 }
149 
150 // Get the number of processes that participate in each group
151 // induced by `meshAxes`.
152 template <typename MeshAxesRange, typename MeshShapeRange>
153 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
154  MeshShapeRange &&meshShape) {
155  int64_t res = 1;
156 
157  for (MeshAxis axis : meshAxes) {
158  auto axisSize = *(std::begin(meshShape) + axis);
159  if (ShapedType::isDynamic(axisSize)) {
160  return ShapedType::kDynamic;
161  }
162  res *= axisSize;
163  }
164 
165  return res;
166 }
167 
168 template <typename MeshAxesRange>
169 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
170  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
171  mesh.getShape());
172 }
173 
174 // Get the size of a sharded dimension.
175 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
176  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
177  return ShapedType::kDynamic;
178 
179  assert(dimSize % shardCount == 0);
180  return dimSize / shardCount;
181 }
182 
183 // Get the size of an unsharded dimension.
184 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
185  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
186  return ShapedType::kDynamic;
187 
188  return dimSize * shardCount;
189 }
190 
191 // Return the sharded shape `shape` according to sharding `sharding`.
192 // The shape for the tensor on each device in the mesh.
193 // Example:
194 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
195 // result in a shape for each shard of ?x2x?.
196 ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
197  MeshSharding sharding);
198 
199 // If ranked tensor type return its sharded counterpart.
200 //
201 // If not ranked tensor type return `type`.
202 // `sharding` in that case must be null.
203 Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
204 
205 // Insert shard op if there is not one that already has the same sharding.
206 // Use newShardOp if it is not null. Otherwise create a new one.
207 // May insert resharding if required.
208 // Potentially updates newShardOp.
210  OpOperand &operand, OpBuilder &builder,
211  ShardOp &newShardOp);
213  OpBuilder &builder);
215  OpOperand &operand,
216  OpBuilder &builder);
217 
218 } // namespace mesh
219 } // namespace mlir
220 
221 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:243
This is a value defined by a result of an operation.
Definition: Value.h:433
This provides public APIs that all operations should have.
Operation * getOperation()
Inherit getOperation from OpState.
Definition: OpDefinition.h:111
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
bool equalSplitAndPartialAxes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:689
ArrayRef< int64_t > getStaticShardedDimsOffsets() const
Definition: MeshOps.h:70
bool equalHaloAndShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:714
::mlir::FlatSymbolRefAttr getMeshAttr() const
Definition: MeshOps.h:64
bool equalHaloSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:734
ArrayRef< MeshAxesAttr > getSplitAxes() const
Definition: MeshOps.h:66
bool operator!=(Value rhs) const
Definition: MeshOps.cpp:750
ReductionKind getPartialType() const
Definition: MeshOps.h:68
ArrayRef< Value > getDynamicShardedDimsOffsets() const
Definition: MeshOps.h:74
bool operator==(Value rhs) const
Definition: MeshOps.cpp:746
ArrayRef< MeshAxis > getPartialAxes() const
Definition: MeshOps.h:67
ArrayRef< Value > getDynamicHaloSizes() const
Definition: MeshOps.h:73
::llvm::StringRef getMesh() const
Definition: MeshOps.h:65
ArrayRef< int64_t > getStaticHaloSizes() const
Definition: MeshOps.h:69
MeshSharding(::mlir::FlatSymbolRefAttr mesh_=nullptr)
Definition: MeshOps.cpp:760
static MeshSharding get(::mlir::FlatSymbolRefAttr mesh_, ArrayRef< MeshAxesAttr > split_axes_, ArrayRef< MeshAxis > partial_axes_={}, ReductionKind partial_type_=ReductionKind::Sum, ArrayRef< int64_t > static_halo_sizes_={}, ArrayRef< int64_t > static_sharded_dims_offsets_={}, ArrayRef< Value > dynamic_halo_sizes_={}, ArrayRef< Value > dynamic_sharded_dims_offsets_={})
Definition: MeshOps.cpp:780
bool equalShardSizes(const MeshSharding &rhs) const
Definition: MeshOps.cpp:718
mesh::ReductionKind ReductionKind
mesh::MeshSharding MeshSharding
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:153
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:120
void maybeInsertSourceShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:329
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:184
Type shardType(Type type, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:270
void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpOperand &operand, OpBuilder &builder, ShardOp &newShardOp)
Definition: MeshOps.cpp:278
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshSharding sharding)
Definition: MeshOps.cpp:260
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:128
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:175
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:100
bool isFullReplication(MeshSharding sharding)
Definition: MeshOps.h:112
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:143
int16_t MeshAxis
Definition: MeshOps.h:26
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:106
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr