MLIR  20.0.0git
MeshOps.h
Go to the documentation of this file.
1 //===- MeshOps.h - Mesh Dialect Operations ----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_MESH_IR_MESHOPS_H
10 #define MLIR_DIALECT_MESH_IR_MESHOPS_H
11 
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/SymbolTable.h"
20 #include "llvm/Support/MathExtras.h"
21 
22 namespace mlir {
23 namespace mesh {
24 
25 using MeshAxis = int16_t;
27 
28 } // namespace mesh
29 } // namespace mlir
30 
31 #include "mlir/Dialect/Mesh/IR/MeshEnums.h.inc"
32 
33 #define GET_ATTRDEF_CLASSES
34 #include "mlir/Dialect/Mesh/IR/MeshAttributes.h.inc"
35 
36 #define GET_OP_CLASSES
37 #include "mlir/Dialect/Mesh/IR/MeshOps.h.inc"
38 
39 namespace mlir {
40 namespace mesh {
41 
42 inline bool isReductionLoop(utils::IteratorType iType) {
43  return iType == utils::IteratorType::reduction;
44 }
45 
46 template <typename T>
48  while (!array.empty() && array.back().empty())
49  array.pop_back();
50 }
51 
52 // Is the same tensor replicated on all processes.
54  return attr.getPartialAxes().empty() &&
55  llvm::all_of(attr.getSplitAxes(), [](MeshAxesAttr axes) {
56  return axes.asArrayRef().empty();
57  });
58 }
59 
60 inline mesh::MeshOp
62  SymbolTableCollection &symbolTableCollection) {
63  return symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
64  op, meshSymbol);
65 }
66 
68  SymbolTableCollection &symbolTableCollection) {
69  mesh::MeshOp meshOp = getMeshOrNull(op, meshSymbol, symbolTableCollection);
70  assert(meshOp);
71  return meshOp;
72 }
73 
74 // Get the corresponding mesh op using the standard attribute nomenclature.
75 template <typename Op>
76 mesh::MeshOp getMesh(Op op, SymbolTableCollection &symbolTableCollection) {
77  return getMesh(op.getOperation(), op.getMeshAttr(), symbolTableCollection);
78 }
79 
80 template <>
82 getMesh<ShardOp>(ShardOp op, SymbolTableCollection &symbolTableCollection) {
83  return getMesh(op.getOperation(), op.getShardAttr().getMesh(),
84  symbolTableCollection);
85 }
86 
87 // Get the number of processes that participate in each group
88 // induced by `meshAxes`.
89 template <typename MeshAxesRange, typename MeshShapeRange>
90 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes,
91  MeshShapeRange &&meshShape) {
92  int64_t res = 1;
93 
94  for (MeshAxis axis : meshAxes) {
95  auto axisSize = *(std::begin(meshShape) + axis);
96  if (ShapedType::isDynamic(axisSize)) {
97  return ShapedType::kDynamic;
98  }
99  res *= axisSize;
100  }
101 
102  return res;
103 }
104 
105 template <typename MeshAxesRange>
106 int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshOp mesh) {
107  return collectiveProcessGroupSize(std::forward<MeshAxesRange>(meshAxes),
108  mesh.getShape());
109 }
110 
111 // Get the size of a sharded dimension.
112 inline int64_t shardDimension(int64_t dimSize, int64_t shardCount) {
113  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
114  return ShapedType::kDynamic;
115 
116  assert(dimSize % shardCount == 0);
117  return dimSize / shardCount;
118 }
119 
120 // Get the size of an unsharded dimension.
121 inline int64_t gatherDimension(int64_t dimSize, int64_t shardCount) {
122  if (ShapedType::isDynamic(dimSize) || ShapedType::isDynamic(shardCount))
123  return ShapedType::kDynamic;
124 
125  return dimSize * shardCount;
126 }
127 
128 // Return the sharded shape `shape` according ot sharding `sharding`.
129 // The shape for the tensor on each device in the mesh.
130 // Example:
131 // On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1 would
132 // result in a shape for each shard of ?x2x?.
133 ShapedType shardShapedType(ShapedType shape, MeshOp mesh,
134  MeshShardingAttr sharding);
135 
136 // If ranked tensor type return its sharded counterpart.
137 //
138 // If not ranked tensor type return `type`.
139 // `sharding` in that case must be null.
140 Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding);
141 
142 // Insert shard op if there is not one that already has the same sharding.
143 // May insert resharding if required.
145  OpOperand &operand,
146  OpBuilder &builder);
148  OpResult result, OpBuilder &builder);
150  OpOperand &operand,
151  OpBuilder &builder);
152 
153 } // namespace mesh
154 } // namespace mlir
155 
156 #endif // MLIR_DIALECT_MESH_IR_MESHOPS_H
A symbol reference with a reference path containing a single element.
This class helps build Operations.
Definition: Builders.h:210
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
mesh::MeshShardingAttr MeshShardingAttr
int64_t collectiveProcessGroupSize(MeshAxesRange &&meshAxes, MeshShapeRange &&meshShape)
Definition: MeshOps.h:90
void maybeInsertTargetShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:182
mesh::MeshOp getMeshOrNull(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:61
int64_t gatherDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:121
mesh::MeshOp getMesh(Operation *op, FlatSymbolRefAttr meshSymbol, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:67
int64_t shardDimension(int64_t dimSize, int64_t shardCount)
Definition: MeshOps.h:112
bool isFullReplication(MeshShardingAttr attr)
Definition: MeshOps.h:53
bool isReductionLoop(utils::IteratorType iType)
Definition: MeshOps.h:42
mesh::MeshOp getMesh< ShardOp >(ShardOp op, SymbolTableCollection &symbolTableCollection)
Definition: MeshOps.h:82
int16_t MeshAxis
Definition: MeshOps.h:25
ShapedType shardShapedType(ShapedType shape, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:163
Type shardType(Type type, MeshOp mesh, MeshShardingAttr sharding)
Definition: MeshOps.cpp:172
void removeTrailingEmptySubArray(SmallVector< SmallVector< T >> &array)
Definition: MeshOps.h:47
void maybeInsertSourceShardingAnnotation(MeshShardingAttr sharding, OpOperand &operand, OpBuilder &builder)
Definition: MeshOps.cpp:222
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int16_t > DenseI16ArrayAttr